[TUTORIAL] Resnet-18 end to end tutorial example (#55)
This commit is contained in:
Родитель
8539ac5810
Коммит
ffe1badd9d
|
@ -0,0 +1,326 @@
|
||||||
|
"""
|
||||||
|
ResNet Inference Example
|
||||||
|
========================
|
||||||
|
**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_
|
||||||
|
|
||||||
|
This tutorial provides an end-to-end demo, on how to run ResNet-18 inference
|
||||||
|
onto the VTA accelerator design to perform ImageNet classification tasks.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
######################################################################
|
||||||
|
# Import Libraries
|
||||||
|
# ----------------
|
||||||
|
# We start by importing the tvm, vta, nnvm libraries to run this example.
|
||||||
|
|
||||||
|
from __future__ import absolute_import, print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import nnvm
|
||||||
|
import nnvm.compiler
|
||||||
|
import tvm
|
||||||
|
import vta
|
||||||
|
import vta.testing
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
|
||||||
|
from nnvm.compiler import graph_attr
|
||||||
|
from tvm.contrib import graph_runtime, rpc, util
|
||||||
|
from tvm.contrib.download import download
|
||||||
|
from vta.testing import simulator
|
||||||
|
|
||||||
|
from io import BytesIO
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# Load VTA parameters from the config.json file
|
||||||
|
env = vta.get_env()
|
||||||
|
|
||||||
|
# Helper to crop an image to a square (224, 224)
|
||||||
|
# Takes in an Image object, returns an Image object
|
||||||
|
def thumbnailify(image, pad=15):
|
||||||
|
w, h = image.size
|
||||||
|
crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad)
|
||||||
|
image = image.crop(crop)
|
||||||
|
image = image.resize((224, 224))
|
||||||
|
return image
|
||||||
|
|
||||||
|
# Helper function to read in image
|
||||||
|
# Takes in Image object, returns an ND array
|
||||||
|
def process_image(image):
|
||||||
|
# Convert to neural network input format
|
||||||
|
image = np.array(image) - np.array([123., 117., 104.])
|
||||||
|
image /= np.array([58.395, 57.12, 57.375])
|
||||||
|
image = image.transpose((2, 0, 1))
|
||||||
|
image = image[np.newaxis, :]
|
||||||
|
|
||||||
|
return tvm.nd.array(image.astype("float32"))
|
||||||
|
|
||||||
|
# Classification helper function
|
||||||
|
# Takes in the graph runtime, and an image, and returns top result and time
|
||||||
|
def classify(m, image):
|
||||||
|
m.set_input('data', image)
|
||||||
|
timer = m.module.time_evaluator("run", ctx, number=1)
|
||||||
|
tcost = timer()
|
||||||
|
tvm_output = m.get_output(0, tvm.nd.empty((1000,), "float32", remote.cpu(0)))
|
||||||
|
top = np.argmax(tvm_output.asnumpy())
|
||||||
|
tcost = "t={0:.2f}s".format(tcost.mean)
|
||||||
|
return tcost + " {}".format(synset[top])
|
||||||
|
|
||||||
|
# Helper function to compile the NNVM graph
|
||||||
|
# Takes in a path to a graph file, params file, and device target
|
||||||
|
# Returns the NNVM graph object, a compiled library object, and the params dict
|
||||||
|
def generate_graph(graph_fn, params_fn, device="vta"):
|
||||||
|
|
||||||
|
# Measure build start time
|
||||||
|
build_start = time.time()
|
||||||
|
|
||||||
|
# Derive the TVM target
|
||||||
|
target = tvm.target.create("llvm -device={}".format(device))
|
||||||
|
|
||||||
|
# Derive the LLVM compiler flags
|
||||||
|
# When targetting the Pynq, cross-compile to ARMv7 ISA
|
||||||
|
if env.TARGET == "sim":
|
||||||
|
target_host = "llvm"
|
||||||
|
elif env.TARGET == "pynq":
|
||||||
|
target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
|
||||||
|
|
||||||
|
# Load the ResNet-18 graph and parameters
|
||||||
|
sym = nnvm.graph.load_json(open(graph_fn).read())
|
||||||
|
params = nnvm.compiler.load_param_dict(open(params_fn, 'rb').read())
|
||||||
|
|
||||||
|
# Populate the shape and data type dictionary
|
||||||
|
shape_dict = {"data": (1, 3, 224, 224)}
|
||||||
|
dtype_dict = {"data": 'float32'}
|
||||||
|
shape_dict.update({k: v.shape for k, v in params.items()})
|
||||||
|
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
|
||||||
|
|
||||||
|
# Create NNVM graph
|
||||||
|
graph = nnvm.graph.create(sym)
|
||||||
|
graph_attr.set_shape_inputs(sym, shape_dict)
|
||||||
|
graph_attr.set_dtype_inputs(sym, dtype_dict)
|
||||||
|
graph = graph.apply("InferShape").apply("InferType")
|
||||||
|
|
||||||
|
# Apply NNVM graph optimization passes
|
||||||
|
sym = vta.graph.clean_cast(sym)
|
||||||
|
sym = vta.graph.clean_conv_fuse(sym)
|
||||||
|
if target.device_name == "vta":
|
||||||
|
assert env.BLOCK_IN == env.BLOCK_OUT
|
||||||
|
sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)
|
||||||
|
|
||||||
|
# Compile NNVM graph
|
||||||
|
with nnvm.compiler.build_config(opt_level=3):
|
||||||
|
if target.device_name != "vta":
|
||||||
|
graph, lib, params = nnvm.compiler.build(
|
||||||
|
sym, target_host, shape_dict, dtype_dict,
|
||||||
|
params=params)
|
||||||
|
else:
|
||||||
|
with vta.build_config():
|
||||||
|
graph, lib, params = nnvm.compiler.build(
|
||||||
|
sym, target, shape_dict, dtype_dict,
|
||||||
|
params=params, target_host=target_host)
|
||||||
|
|
||||||
|
# Save the compiled inference graph library
|
||||||
|
assert tvm.module.enabled("rpc")
|
||||||
|
temp = util.tempdir()
|
||||||
|
lib.save(temp.relpath("graphlib.o"))
|
||||||
|
|
||||||
|
# Send the inference library over to the remote RPC server
|
||||||
|
remote.upload(temp.relpath("graphlib.o"))
|
||||||
|
lib = remote.load_module("graphlib.o")
|
||||||
|
|
||||||
|
# Measure build time
|
||||||
|
build_time = time.time() - build_start
|
||||||
|
print("ResNet-18 inference graph built in {0:.2f}s!".format(build_time))
|
||||||
|
|
||||||
|
return graph, lib, params
|
||||||
|
|
||||||
|
|
||||||
|
######################################################################
|
||||||
|
# Download ResNet Model
|
||||||
|
# --------------------------------------------
|
||||||
|
# Download the necessary files to run ResNet-18.
|
||||||
|
#
|
||||||
|
|
||||||
|
# Obtain ResNet model and download them into _data dir
|
||||||
|
url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
|
||||||
|
categ_fn = 'synset.txt'
|
||||||
|
graph_fn = 'resnet18_qt8.json'
|
||||||
|
params_fn = 'resnet18_qt8.params'
|
||||||
|
|
||||||
|
# Create data dir
|
||||||
|
data_dir = "_data/"
|
||||||
|
if not os.path.exists(data_dir):
|
||||||
|
os.makedirs(data_dir)
|
||||||
|
|
||||||
|
# Download files
|
||||||
|
for file in [categ_fn, graph_fn, params_fn]:
|
||||||
|
if not os.path.isfile(file):
|
||||||
|
download(os.path.join(url, file), os.path.join(data_dir, file))
|
||||||
|
|
||||||
|
# Read in ImageNet Categories
|
||||||
|
synset = eval(open(os.path.join(data_dir, categ_fn)).read())
|
||||||
|
|
||||||
|
|
||||||
|
######################################################################
|
||||||
|
# Setup the Pynq Board's RPC Server
|
||||||
|
# ---------------------------------
|
||||||
|
# Build the RPC server's VTA runtime and program the Pynq FPGA.
|
||||||
|
|
||||||
|
# Measure build start time
|
||||||
|
reconfig_start = time.time()
|
||||||
|
|
||||||
|
# We read the Pynq RPC host IP address and port number from the OS environment
|
||||||
|
host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
|
||||||
|
port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))
|
||||||
|
|
||||||
|
# We configure both the bitstream and the runtime system on the Pynq
|
||||||
|
# to match the VTA configuration specified by the config.json file.
|
||||||
|
if env.TARGET == "pynq":
|
||||||
|
|
||||||
|
# Make sure that TVM was compiled with RPC=1
|
||||||
|
assert tvm.module.enabled("rpc")
|
||||||
|
remote = rpc.connect(host, port)
|
||||||
|
|
||||||
|
# Reconfigure the JIT runtime
|
||||||
|
vta.reconfig_runtime(remote)
|
||||||
|
|
||||||
|
# Program the FPGA with a pre-compiled VTA bitstream.
|
||||||
|
# You can program the FPGA with your own custom bitstream
|
||||||
|
# by passing the path to the bitstream file instead of None.
|
||||||
|
vta.program_fpga(remote, bitstream=None)
|
||||||
|
|
||||||
|
# Report on reconfiguration time
|
||||||
|
reconfig_time = time.time() - reconfig_start
|
||||||
|
print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
|
||||||
|
|
||||||
|
# In simulation mode, host the RPC server locally.
|
||||||
|
elif env.TARGET == "sim":
|
||||||
|
remote = rpc.LocalSession()
|
||||||
|
|
||||||
|
|
||||||
|
######################################################################
|
||||||
|
# Build the ResNet Runtime
|
||||||
|
# ------------------------
|
||||||
|
# Build the ResNet graph runtime, and configure the parameters.
|
||||||
|
|
||||||
|
# Set ``device=cpu`` to run inference on the CPU,
|
||||||
|
# or ``device=vtacpu`` to run inference on the FPGA.
|
||||||
|
device = "vta"
|
||||||
|
|
||||||
|
# Device context
|
||||||
|
ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
|
||||||
|
|
||||||
|
# Build the graph runtime
|
||||||
|
graph, lib, params = generate_graph(os.path.join(data_dir, graph_fn),
|
||||||
|
os.path.join(data_dir, params_fn),
|
||||||
|
device)
|
||||||
|
m = graph_runtime.create(graph, lib, ctx)
|
||||||
|
|
||||||
|
# Set the parameters
|
||||||
|
m.set_input(**params)
|
||||||
|
|
||||||
|
|
||||||
|
######################################################################
|
||||||
|
# Run ResNet-18 inference on a sample image
|
||||||
|
# -----------------------------------------
|
||||||
|
# Perform image classification on test image.
|
||||||
|
# You can change the test image URL to any image of your choosing.
|
||||||
|
|
||||||
|
# Read in test image
|
||||||
|
image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
|
||||||
|
# Read in test image
|
||||||
|
response = requests.get(image_url)
|
||||||
|
image = Image.open(BytesIO(response.content)).resize((224, 224))
|
||||||
|
# Show Image
|
||||||
|
plt.imshow(image)
|
||||||
|
plt.show()
|
||||||
|
# Set the input
|
||||||
|
image = process_image(image)
|
||||||
|
m.set_input('data', image)
|
||||||
|
|
||||||
|
# Perform inference
|
||||||
|
timer = m.module.time_evaluator("run", ctx, number=1)
|
||||||
|
tcost = timer()
|
||||||
|
|
||||||
|
# Get classification results
|
||||||
|
tvm_output = m.get_output(0, tvm.nd.empty((1000,), "float32", remote.cpu(0)))
|
||||||
|
top_categories = np.argsort(tvm_output.asnumpy())
|
||||||
|
|
||||||
|
# Report top-5 classification results
|
||||||
|
print("ResNet-18 Prediction #1:", synset[top_categories[-1]])
|
||||||
|
print(" #2:", synset[top_categories[-2]])
|
||||||
|
print(" #3:", synset[top_categories[-3]])
|
||||||
|
print(" #4:", synset[top_categories[-4]])
|
||||||
|
print(" #5:", synset[top_categories[-5]])
|
||||||
|
print("Performed inference in {0:.2f}s".format(tcost.mean))
|
||||||
|
|
||||||
|
|
||||||
|
######################################################################
|
||||||
|
# Run a Youtube Video Image Classifier
|
||||||
|
# ------------------------------------
|
||||||
|
# Perform image classification on test stream on 1 frame every 48 frames.
|
||||||
|
# Comment the `if False:` out to run the demo
|
||||||
|
|
||||||
|
# Early exit - remove for Demo
|
||||||
|
if False:
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import pafy
|
||||||
|
from IPython.display import clear_output
|
||||||
|
|
||||||
|
# Helper to crop an image to a square (224, 224)
|
||||||
|
# Takes in an Image object, returns an Image object
|
||||||
|
def thumbnailify(image, pad=15):
|
||||||
|
w, h = image.size
|
||||||
|
crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad)
|
||||||
|
image = image.crop(crop)
|
||||||
|
image = image.resize((224, 224))
|
||||||
|
return image
|
||||||
|
|
||||||
|
# 16:16 inches
|
||||||
|
plt.rcParams['figure.figsize'] = [16, 16]
|
||||||
|
|
||||||
|
# Stream the video in
|
||||||
|
url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s"
|
||||||
|
video = pafy.new(url)
|
||||||
|
best = video.getbest(preftype="mp4")
|
||||||
|
cap = cv2.VideoCapture(best.url)
|
||||||
|
|
||||||
|
# Process one frame out of every 48 for variety
|
||||||
|
count = 0
|
||||||
|
guess = ""
|
||||||
|
while(count<2400):
|
||||||
|
|
||||||
|
# Capture frame-by-frame
|
||||||
|
ret, frame = cap.read()
|
||||||
|
|
||||||
|
# Process one every 48 frames
|
||||||
|
if count % 48 == 1:
|
||||||
|
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
|
frame = Image.fromarray(frame)
|
||||||
|
# Crop and resize
|
||||||
|
thumb = np.array(thumbnailify(frame))
|
||||||
|
image = process_image(thumb)
|
||||||
|
guess = classify(m, image)
|
||||||
|
|
||||||
|
# Insert guess in frame
|
||||||
|
frame = cv2.rectangle(thumb,(0,0),(200,0),(0,0,0),50)
|
||||||
|
cv2.putText(frame, guess, (5,15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (256,256,256), 1, cv2.LINE_AA)
|
||||||
|
|
||||||
|
plt.imshow(thumb)
|
||||||
|
plt.axis('off')
|
||||||
|
plt.show()
|
||||||
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
|
break
|
||||||
|
clear_output(wait=True)
|
||||||
|
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
# When everything done, release the capture
|
||||||
|
cap.release()
|
||||||
|
cv2.destroyAllWindows()
|
Загрузка…
Ссылка в новой задаче