[VTA] [Hardware] Chisel implementation (#3258)
This commit is contained in:
Родитель
1f62d9561c
Коммит
32f74f31c8
|
@ -132,9 +132,6 @@ set(USE_SORT ON)
|
|||
# Build ANTLR parser for Relay text format
|
||||
set(USE_ANTLR OFF)
|
||||
|
||||
# Build TSIM for VTA
|
||||
set(USE_VTA_TSIM OFF)
|
||||
|
||||
# Whether use Relay debug mode
|
||||
set(USE_RELAY_DEBUG OFF)
|
||||
|
||||
|
|
|
@ -29,8 +29,7 @@ elseif(PYTHON)
|
|||
--use-cfg=${CMAKE_CURRENT_BINARY_DIR}/vta_config.json)
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE __vta_target)
|
||||
string(STRIP ${__vta_target} VTA_TARGET)
|
||||
execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE VTA_TARGET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
|
||||
message(STATUS "Build VTA runtime with target: " ${VTA_TARGET})
|
||||
|
||||
|
@ -44,6 +43,13 @@ elseif(PYTHON)
|
|||
|
||||
add_library(vta SHARED ${VTA_RUNTIME_SRCS})
|
||||
|
||||
# When the configured VTA target is "tsim", build the vta library with the
# USE_TSIM definition and add the DPI module source to the runtime sources.
#
# The variable expansion is quoted on purpose: if VTA_TARGET is empty (for
# example the config command printed nothing) an unquoted ${VTA_TARGET} would
# expand to `if( STREQUAL "tsim")`, which is a configure-time error, and an
# unquoted value could also be dereferenced a second time if it happens to
# name another variable.
if("${VTA_TARGET}" STREQUAL "tsim")
  target_compile_definitions(vta PUBLIC USE_TSIM)
  include_directories("vta/include")
  file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc)
  list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS})
endif()
|
||||
|
||||
target_include_directories(vta PUBLIC vta/include)
|
||||
|
||||
foreach(__def ${VTA_DEFINITIONS})
|
||||
|
@ -61,12 +67,6 @@ elseif(PYTHON)
|
|||
target_link_libraries(vta ${__cma_lib})
|
||||
endif()
|
||||
|
||||
if(NOT USE_VTA_TSIM STREQUAL "OFF")
|
||||
include_directories("vta/include")
|
||||
file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc)
|
||||
list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS})
|
||||
endif()
|
||||
|
||||
else()
|
||||
message(STATUS "Cannot found python in env, VTA build is skipped..")
|
||||
endif()
|
||||
|
|
|
@ -49,7 +49,7 @@ sudo apt install verilator sbt
|
|||
## Setup in TVM
|
||||
|
||||
1. Install `verilator` and `sbt` as described above
|
||||
2. Enable VTA TSIM by turning on the switch `USE_VTA_TSIM` in config.cmake
|
||||
2. Set the VTA TARGET to `tsim` in `<tvm-root>/vta/config/vta_config.json`
|
||||
3. Build tvm
|
||||
|
||||
## How to run VTA TSIM examples
|
||||
|
|
|
@ -124,7 +124,7 @@ else()
|
|||
file(GLOB VERILATOR_SRC ${VTA_HW_DPI_DIR}/tsim_device.cc)
|
||||
add_library(hw SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC})
|
||||
|
||||
set(VERILATOR_DEF VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0)
|
||||
set(VERILATOR_DEF VL_USER_FINISH VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0)
|
||||
if (NOT TSIM_USE_TRACE STREQUAL "OFF")
|
||||
list(APPEND VERILATOR_DEF VM_TRACE=1 TSIM_TRACE_FILE=${TSIM_BUILD_DIR}/${TSIM_TRACE_NAME}.vcd)
|
||||
else()
|
||||
|
|
|
@ -15,5 +15,81 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# User-tunable settings (override on the command line, e.g. `make USE_TRACE=1`).
#   CONFIG      - configuration name; selects the sbt main class vta.$(CONFIG)
#   TOP         - top name used by the `verilog` target
#   TOP_TEST    - top name used by the simulation targets (verilator / lib)
#   BUILD_NAME  - name of the build directory created under $(vta_dir)
#   USE_TRACE   - any value other than 0 enables Verilator VCD tracing
#   VTA_LIBNAME - base name of the shared library produced by `lib`
CONFIG = DefaultF1Config
TOP = VTA
TOP_TEST = Test
BUILD_NAME = build
USE_TRACE = 0
VTA_LIBNAME = libvta_hw

# Derived names and directories.
config_test = $(TOP_TEST)$(CONFIG)
vta_dir = $(abspath ../../)
tvm_dir = $(abspath ../../../)
verilator_inc_dir = /usr/local/share/verilator/include
verilator_build_dir = $(vta_dir)/$(BUILD_NAME)/verilator
chisel_build_dir = $(vta_dir)/$(BUILD_NAME)/chisel

# Options passed to verilator when translating the generated Verilog to C++.
verilator_opt = --cc
verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN
verilator_opt += +define+RANDOMIZE_REG_INIT
verilator_opt += +define+RANDOMIZE_MEM_INIT
verilator_opt += --x-assign unique
verilator_opt += --output-split 20000
verilator_opt += --output-split-cfuncs 20000
verilator_opt += --top-module $(TOP_TEST)
verilator_opt += -Mdir $(verilator_build_dir)
verilator_opt += -I$(chisel_build_dir)

# Compiler flags for the shared library built from the verilated sources.
cxx_flags = -O2 -Wall -fPIC -shared
cxx_flags += -fvisibility=hidden -std=c++11
cxx_flags += -DVL_TSIM_NAME=V$(TOP_TEST)
cxx_flags += -DVL_PRINTF=printf
cxx_flags += -DVL_USER_FINISH
cxx_flags += -DVM_COVERAGE=0
cxx_flags += -DVM_SC=0
cxx_flags += -Wno-sign-compare
cxx_flags += -include V$(TOP_TEST).h
cxx_flags += -I$(verilator_build_dir)
cxx_flags += -I$(verilator_inc_dir)
cxx_flags += -I$(verilator_inc_dir)/vltstd
cxx_flags += -I$(vta_dir)/include
cxx_flags += -I$(tvm_dir)/include
cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include

# Sources compiled into the library. cxx_files is a recursively expanded
# variable, so the wildcard is evaluated when the recipe runs, i.e. after
# verilator has written its output files.
cxx_files = $(verilator_inc_dir)/verilated.cpp
cxx_files += $(verilator_inc_dir)/verilated_dpi.cpp
cxx_files += $(wildcard $(verilator_build_dir)/*.cpp)
cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc

ifneq ($(USE_TRACE), 0)
  verilator_opt += --trace
  cxx_flags += -DVM_TRACE=1
  cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP_TEST).vcd
  cxx_files += $(verilator_inc_dir)/verilated_vcd_c.cpp
else
  cxx_flags += -DVM_TRACE=0
endif

# These targets are commands, not files. Declaring them phony keeps them
# working even if a file or directory with the same name exists.
.PHONY: default lib verilator verilog verilog_test clean cleanall

default: lib

# Shared library containing the verilated design. $(CXX) defaults to g++ in
# GNU make and can be overridden to select a different compiler.
lib: $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so
$(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so: $(verilator_build_dir)/V$(TOP_TEST).cpp
	$(CXX) $(cxx_flags) $(cxx_files) -o $@

# Verilog -> C++ using verilator.
verilator: $(verilator_build_dir)/V$(TOP_TEST).cpp
$(verilator_build_dir)/V$(TOP_TEST).cpp: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
	verilator $(verilator_opt) $<

# Chisel -> Verilog for the $(TOP) design.
verilog: $(chisel_build_dir)/$(TOP).$(CONFIG).v
$(chisel_build_dir)/$(TOP).$(CONFIG).v:
	sbt 'runMain vta.$(CONFIG) --target-dir $(chisel_build_dir) --top-name $(TOP).$(CONFIG)'

# Chisel -> Verilog for the $(TOP_TEST) design used by the simulation build.
verilog_test: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
$(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v:
	sbt 'runMain vta.$(config_test) --target-dir $(chisel_build_dir) --top-name $(TOP_TEST).$(CONFIG)'

# Remove sbt output directories in the current directory.
clean:
	-rm -rf target project/target project/project

# Remove the whole build directory.
cleanall:
	-rm -rf $(vta_dir)/$(BUILD_NAME)
|
||||
|
|
|
@ -112,7 +112,7 @@ module VTAHostDPI #
|
|||
|
||||
always_ff @(posedge clock) begin
|
||||
if (__exit == 'd1) begin
|
||||
$display("[DONE] at cycle:%016d", cycles);
|
||||
$display("[TSIM] Verilog $finish called at cycle:%016d", cycles);
|
||||
$finish;
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,201 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
import vta.shell._
|
||||
|
||||
/** Compute.
  *
  * Buffers incoming instructions in a queue, decodes the instruction at the
  * head of the queue and executes it on the matching sub-module:
  * - loadUop: loads micro-ops from memory
  * - tensorAcc: loads biases (acc) from memory
  * - tensorGemm: executes GEMM instructions
  * - tensorAlu: executes ALU instructions
  *
  * An instruction is only started once the semaphores selected by its
  * pop_prev/pop_next flags are ready; tokens are posted into those semaphores
  * through `i_post`. When an instruction retires, `o_post` is asserted
  * according to its push_prev/push_next flags.
  *
  * @param debug when true, printf statements tracing start/done events are
  *              elaborated into the design
  */
class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
  val mp = p(ShellKey).memParams
  val io = IO(new Bundle {
    val i_post = Vec(2, Input(Bool()))
    val o_post = Vec(2, Output(Bool()))
    val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
    val uop_baddr = Input(UInt(mp.addrBits.W))
    val acc_baddr = Input(UInt(mp.addrBits.W))
    val vme_rd = Vec(2, new VMEReadMaster)
    val inp = new TensorMaster(tensorType = "inp")
    val wgt = new TensorMaster(tensorType = "wgt")
    val out = new TensorMaster(tensorType = "out")
    val finish = Output(Bool())
  })

  // Three-state controller: wait for an instruction, one-cycle sync, execute.
  val sIdle :: sSync :: sExe :: Nil = Enum(3)
  val state = RegInit(sIdle)

  // s(0) is checked for pop_prev, s(1) for pop_next.
  val s = Seq.fill(2)(Module(new Semaphore(counterBits = 8, counterInitValue = 0)))

  val loadUop = Module(new LoadUop)
  val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
  val tensorGemm = Module(new TensorGemm)
  val tensorAlu = Module(new TensorAlu)

  val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))

  // Decode the instruction currently at the head of the queue.
  val dec = Module(new ComputeDecode)
  dec.io.inst := inst_q.io.deq.bits

  // One bit per instruction kind; bit 0 is load-uop, bit 4 is finish.
  val inst_type =
    Cat(dec.io.isFinish, dec.io.isAlu, dec.io.isGemm, dec.io.isLoadAcc, dec.io.isLoadUop).asUInt

  // A valid head instruction may start when every semaphore it pops is ready.
  val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B)
  val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B)
  val start = snext & sprev

  // Completion signal of whichever unit matches the decoded instruction kind.
  val done = MuxLookup(
    inst_type,
    false.B, // default
    Seq(
      "h_01".U -> loadUop.io.done,
      "h_02".U -> tensorAcc.io.done,
      "h_04".U -> tensorGemm.io.done,
      "h_08".U -> tensorAlu.io.done,
      "h_10".U -> true.B // Finish
    )
  )

  // issue: the cycle in which the head instruction is started.
  // retire: the cycle in which the head instruction leaves the queue.
  val issue = state === sIdle & start
  val retire = (state === sExe & done) | (state === sSync)

  // control
  switch (state) {
    is (sIdle) {
      when (start && dec.io.isSync) {
        state := sSync
      } .elsewhen (start && inst_type.orR) {
        state := sExe
      }
    }
    is (sSync) {
      state := sIdle
    }
    is (sExe) {
      when (done) { state := sIdle }
    }
  }

  // instructions
  inst_q.io.enq <> io.inst
  inst_q.io.deq.ready := retire

  // uop
  loadUop.io.start := issue & dec.io.isLoadUop
  loadUop.io.inst := inst_q.io.deq.bits
  loadUop.io.baddr := io.uop_baddr
  io.vme_rd(0) <> loadUop.io.vme_rd
  loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)

  // acc
  tensorAcc.io.start := issue & dec.io.isLoadAcc
  tensorAcc.io.inst := inst_q.io.deq.bits
  tensorAcc.io.baddr := io.acc_baddr
  tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx)
  tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr)
  io.vme_rd(1) <> tensorAcc.io.vme_rd

  // gemm: read-data valids are qualified with isGemm
  tensorGemm.io.start := issue & dec.io.isGemm
  tensorGemm.io.inst := inst_q.io.deq.bits
  tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
  tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
  tensorGemm.io.inp <> io.inp
  tensorGemm.io.wgt <> io.wgt
  tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isGemm
  tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
  tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm
  tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits

  // alu: read-data valids are qualified with isAlu
  tensorAlu.io.start := issue & dec.io.isAlu
  tensorAlu.io.inst := inst_q.io.deq.bits
  tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
  tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
  tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isAlu
  tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
  tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu
  tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits

  // out: driven by the GEMM unit for GEMM instructions, otherwise by the ALU
  io.out.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.out.rd.idx, tensorAlu.io.out.rd.idx)
  io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)

  // semaphore
  s(0).io.spost := io.i_post(0)
  s(1).io.spost := io.i_post(1)
  s(0).io.swait := dec.io.pop_prev & issue
  s(1).io.swait := dec.io.pop_next & issue
  io.o_post(0) := dec.io.push_prev & retire
  io.o_post(1) := dec.io.push_next & retire

  // finish
  io.finish := state === sExe & done & dec.io.isFinish

  // debug
  if (debug) {
    // start
    when (issue) {
      when (dec.io.isSync) {
        printf("[Compute] start sync\n")
      } .elsewhen (dec.io.isLoadUop) {
        printf("[Compute] start load uop\n")
      } .elsewhen (dec.io.isLoadAcc) {
        printf("[Compute] start load acc\n")
      } .elsewhen (dec.io.isGemm) {
        printf("[Compute] start gemm\n")
      } .elsewhen (dec.io.isAlu) {
        printf("[Compute] start alu\n")
      } .elsewhen (dec.io.isFinish) {
        printf("[Compute] start finish\n")
      }
    }
    // done
    when (state === sSync) {
      printf("[Compute] done sync\n")
    }
    when (state === sExe && done) {
      when (dec.io.isLoadUop) {
        printf("[Compute] done load uop\n")
      } .elsewhen (dec.io.isLoadAcc) {
        printf("[Compute] done load acc\n")
      } .elsewhen (dec.io.isGemm) {
        printf("[Compute] done gemm\n")
      } .elsewhen (dec.io.isAlu) {
        printf("[Compute] done alu\n")
      } .elsewhen (dec.io.isFinish) {
        printf("[Compute] done finish\n")
      }
    }
  }
}
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import vta.util.config._
|
||||
|
||||
/** CoreConfig.
  *
  * One supported VTA configuration: binds CoreKey to a concrete CoreParams
  * instance. The intent is for this file to eventually hold several
  * configuration classes that can be mixed and matched with Shell
  * configurations for different backends.
  */
class CoreConfig extends Config((site, here, up) => {
  case CoreKey =>
    CoreParams(
      // block geometry
      batch = 1,
      blockIn = 16,
      blockOut = 16,
      // element widths in bits
      inpBits = 8,
      wgtBits = 8,
      accBits = 32,
      outBits = 8,
      uopBits = 32,
      // memory depths (entries)
      uopMemDepth = 2048,
      inpMemDepth = 2048,
      wgtMemDepth = 1024,
      accMemDepth = 2048,
      outMemDepth = 2048,
      // instruction queue depth
      instQueueEntries = 512
    )
})
|
|
@ -0,0 +1,109 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import vta.util.config._
|
||||
import vta.shell._
|
||||
|
||||
/** Core parameters.
  *
  * Plain container of integer parameters for the VTA core; each field has a
  * default value. The *Bits fields are presumably element widths in bits and
  * the *MemDepth fields presumably scratchpad depths in entries - confirm
  * against the modules that read them.
  */
|
||||
case class CoreParams (
|
||||
batch: Int = 1,
|
||||
blockOut: Int = 16,
|
||||
blockIn: Int = 16,
|
||||
inpBits: Int = 8,
|
||||
wgtBits: Int = 8,
|
||||
uopBits: Int = 32,
|
||||
accBits: Int = 32,
|
||||
outBits: Int = 8,
|
||||
uopMemDepth: Int = 512,
|
||||
inpMemDepth: Int = 512,
|
||||
wgtMemDepth: Int = 512,
|
||||
accMemDepth: Int = 512,
|
||||
outMemDepth: Int = 512,
|
||||
// number of entries of an instruction queue
|
||||
instQueueEntries: Int = 32
|
||||
)
|
||||
|
||||
/** Configuration key whose value is a CoreParams instance. */
case object CoreKey extends Field[CoreParams]
|
||||
|
||||
/** Core.
  *
  * Defines the current VTA architecture by instantiating the fetch, load,
  * compute and store modules and wiring them together. Most connections are
  * bulk connections (<>), which should stay that way because it keeps the
  * structure easy to follow.
  *
  * The core must be instantiated by a shell through the VTA Control Register
  * (VCR) and VTA Memory Engine (VME) interfaces; see the shell directory for
  * details on those interfaces and modules.
  */
class Core(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {
    val vcr = new VCRClient
    val vme = new VMEMaster
  })
  val fetch = Module(new Fetch)
  val load = Module(new Load)
  val compute = Module(new Compute)
  val store = Module(new Store)

  // Memory (i.e. DRAM) read ports, listed in VME read-port order 0..4.
  Seq(
    fetch.io.vme_rd,
    compute.io.vme_rd(0),
    load.io.vme_rd(0),
    load.io.vme_rd(1),
    compute.io.vme_rd(1)
  ).zipWithIndex.foreach { case (client, idx) => io.vme.rd(idx) <> client }

  // Memory (i.e. DRAM) write port.
  io.vme.wr(0) <> store.io.vme_wr

  // Fetch: reads instructions (tasks) from memory (DRAM) into queues (SRAMs).
  fetch.io.launch := io.vcr.launch
  fetch.io.ins_baddr := io.vcr.ptrs(0)
  fetch.io.ins_count := io.vcr.vals(0)

  // Load: brings inputs and weights from memory (DRAM) into scratchpads (SRAMs).
  load.io.inst <> fetch.io.inst.ld
  load.io.inp_baddr := io.vcr.ptrs(2)
  load.io.wgt_baddr := io.vcr.ptrs(3)
  load.io.i_post := compute.io.o_post(0)

  // Compute:
  // - loads micro-ops (uops) and accumulations (acc)
  // - executes dense and ALU instructions (tasks)
  compute.io.inst <> fetch.io.inst.co
  compute.io.uop_baddr := io.vcr.ptrs(1)
  compute.io.acc_baddr := io.vcr.ptrs(4)
  compute.io.inp <> load.io.inp
  compute.io.wgt <> load.io.wgt
  compute.io.i_post(0) := load.io.o_post
  compute.io.i_post(1) := store.io.o_post

  // Store:
  // - writes results from compute into scratchpads (SRAMs)
  // - stores results from scratchpads (SRAMs) to memory (DRAM)
  store.io.inst <> fetch.io.inst.st
  store.io.out_baddr := io.vcr.ptrs(5)
  store.io.out <> compute.io.out
  store.io.i_post := compute.io.o_post(1)

  // The finish flag from compute is registered once before it reaches the VCR.
  val finish = RegNext(compute.io.finish)
  io.vcr.finish := finish
}
|
|
@ -0,0 +1,229 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
|
||||
import ISA._
|
||||
|
||||
/** MemDecode.
|
||||
*
|
||||
* Decode memory instructions with a Bundle. This is similar to a union,
|
||||
* therefore order matters when declaring fields. These are the instructions
|
||||
* decoded with this bundle:
|
||||
* - LUOP
|
||||
* - LWGT
|
||||
* - LINP
|
||||
* - LACC
|
||||
* - SOUT
|
||||
*
* The decoder modules in this file obtain it with
* `io.inst.asTypeOf(new MemDecode)`; do not reorder or resize fields without
* updating the instruction encoding.
*/
|
||||
class MemDecode extends Bundle {
|
||||
val xpad_1 = UInt(M_PAD_BITS.W)
|
||||
val xpad_0 = UInt(M_PAD_BITS.W)
|
||||
val ypad_1 = UInt(M_PAD_BITS.W)
|
||||
val ypad_0 = UInt(M_PAD_BITS.W)
|
||||
val xstride = UInt(M_STRIDE_BITS.W)
|
||||
val xsize = UInt(M_SIZE_BITS.W)
|
||||
val ysize = UInt(M_SIZE_BITS.W)
|
||||
val empty_0 = UInt(7.W) // presumably unused padding; the width is hard-coded (TODO: derive this)
|
||||
val dram_offset = UInt(M_DRAM_OFFSET_BITS.W)
|
||||
val sram_offset = UInt(M_SRAM_OFFSET_BITS.W)
|
||||
val id = UInt(M_ID_BITS.W)
|
||||
// dependence flags, forwarded unchanged by the decoder modules
val push_next = Bool()
|
||||
val push_prev = Bool()
|
||||
val pop_next = Bool()
|
||||
val pop_prev = Bool()
|
||||
val op = UInt(OP_BITS.W)
|
||||
}
|
||||
|
||||
/** GemmDecode.
|
||||
*
|
||||
* Decode GEMM instruction with a Bundle. This is similar to a union,
|
||||
* therefore order matters when declaring fields.
|
||||
*
* The trailing fields (reset, push/pop flags, op) are declared in the same
* order as in AluDecode.
*/
|
||||
class GemmDecode extends Bundle {
|
||||
val wgt_1 = UInt(C_WIDX_BITS.W)
|
||||
val wgt_0 = UInt(C_WIDX_BITS.W)
|
||||
val inp_1 = UInt(C_IIDX_BITS.W)
|
||||
val inp_0 = UInt(C_IIDX_BITS.W)
|
||||
val acc_1 = UInt(C_AIDX_BITS.W)
|
||||
val acc_0 = UInt(C_AIDX_BITS.W)
|
||||
val empty_0 = Bool() // presumably an unused bit - confirm against the ISA definition
|
||||
val lp_1 = UInt(C_ITER_BITS.W)
|
||||
val lp_0 = UInt(C_ITER_BITS.W)
|
||||
val uop_end = UInt(C_UOP_END_BITS.W)
|
||||
val uop_begin = UInt(C_UOP_BGN_BITS.W)
|
||||
val reset = Bool()
|
||||
val push_next = Bool()
|
||||
val push_prev = Bool()
|
||||
val pop_next = Bool()
|
||||
val pop_prev = Bool()
|
||||
val op = UInt(OP_BITS.W)
|
||||
}
|
||||
|
||||
/** AluDecode.
|
||||
*
|
||||
* Decode ALU instructions with a Bundle. This is similar to a union,
|
||||
* therefore order matters when declaring fields. These are the instructions
|
||||
* decoded with this bundle:
|
||||
* - VMIN
|
||||
* - VMAX
|
||||
* - VADD
|
||||
* - VSHX
|
||||
*
* The fields from empty_0 down to op are declared in the same order as in
* GemmDecode.
*/
|
||||
class AluDecode extends Bundle {
|
||||
val empty_1 = Bool() // presumably an unused bit - confirm against the ISA definition
|
||||
val alu_imm = UInt(C_ALU_IMM_BITS.W)
|
||||
val alu_use_imm = Bool()
|
||||
val alu_op = UInt(C_ALU_DEC_BITS.W)
|
||||
val src_1 = UInt(C_IIDX_BITS.W)
|
||||
val src_0 = UInt(C_IIDX_BITS.W)
|
||||
val dst_1 = UInt(C_AIDX_BITS.W)
|
||||
val dst_0 = UInt(C_AIDX_BITS.W)
|
||||
val empty_0 = Bool() // presumably an unused bit - confirm against the ISA definition
|
||||
val lp_1 = UInt(C_ITER_BITS.W)
|
||||
val lp_0 = UInt(C_ITER_BITS.W)
|
||||
val uop_end = UInt(C_UOP_END_BITS.W)
|
||||
val uop_begin = UInt(C_UOP_BGN_BITS.W)
|
||||
val reset = Bool()
|
||||
val push_next = Bool()
|
||||
val push_prev = Bool()
|
||||
val pop_next = Bool()
|
||||
val pop_prev = Bool()
|
||||
val op = UInt(OP_BITS.W)
|
||||
}
|
||||
|
||||
/** UopDecode.
|
||||
*
|
||||
* Decode micro-ops (uops).
|
||||
*
* Three fields with hard-coded widths of 10 + 11 + 11 = 32 bits in total.
* NOTE(review): presumably this must match the configured uopBits (32 in
* CoreParams) - confirm.
*/
|
||||
class UopDecode extends Bundle {
|
||||
val u2 = UInt(10.W)
|
||||
val u1 = UInt(11.W)
|
||||
val u0 = UInt(11.W)
|
||||
}
|
||||
|
||||
/** FetchDecode.
|
||||
*
|
||||
* Partial decoding for dispatching instructions to Load, Compute, and Store.
|
||||
*
* Inputs:  inst - the instruction word to classify
* Outputs: isLoad, isCompute, isStore - asserted when the instruction matches
*          a table entry whose op type is OP_L, OP_G or OP_S respectively
*/
|
||||
class FetchDecode extends Module {
|
||||
val io = IO(new Bundle {
|
||||
val inst = Input(UInt(INST_BITS.W))
|
||||
val isLoad = Output(Bool())
|
||||
val isCompute = Output(Bool())
|
||||
val isStore = Output(Bool())
|
||||
})
|
||||
// Lookup table: instruction pattern -> List(valid flag, op type).
// An instruction that matches no entry gets the default List(N, OP_X).
val csignals =
|
||||
ListLookup(io.inst,
|
||||
List(N, OP_X),
|
||||
Array(
|
||||
LUOP -> List(Y, OP_G),
|
||||
LWGT -> List(Y, OP_L),
|
||||
LINP -> List(Y, OP_L),
|
||||
LACC -> List(Y, OP_G),
|
||||
SOUT -> List(Y, OP_S),
|
||||
GEMM -> List(Y, OP_G),
|
||||
FNSH -> List(Y, OP_G),
|
||||
VMIN -> List(Y, OP_G),
|
||||
VMAX -> List(Y, OP_G),
|
||||
VADD -> List(Y, OP_G),
|
||||
VSHX -> List(Y, OP_G)
|
||||
)
|
||||
)
|
||||

|
||||
// Split the looked-up list into the valid flag and the op type.
val (cs_val_inst: Bool) :: cs_op_type :: Nil = csignals
|
||||

|
||||
// Each output requires both a valid table hit and the matching op type.
io.isLoad := cs_val_inst & cs_op_type === OP_L
|
||||
io.isCompute := cs_val_inst & cs_op_type === OP_G
|
||||
io.isStore := cs_val_inst & cs_op_type === OP_S
|
||||
}
|
||||
|
||||
/** LoadDecode.
  *
  * Decodes dependency flags, instruction type and sync for the Load module.
  * An LINP/LWGT instruction whose xsize field is zero is reported as a sync
  * rather than as an input/weight load.
  */
class LoadDecode extends Module {
  val io = IO(new Bundle {
    val inst = Input(UInt(INST_BITS.W))
    val push_next = Output(Bool())
    val pop_next = Output(Bool())
    val isInput = Output(Bool())
    val isWeight = Output(Bool())
    val isSync = Output(Bool())
  })
  val dec = io.inst.asTypeOf(new MemDecode)

  // Building blocks shared by the type outputs below.
  val zeroSize = dec.xsize === 0.U
  val isLinp = io.inst === LINP
  val isLwgt = io.inst === LWGT

  // Dependency flags are forwarded unchanged from the instruction.
  io.push_next := dec.push_next
  io.pop_next := dec.pop_next

  io.isInput := isLinp & !zeroSize
  io.isWeight := isLwgt & !zeroSize
  io.isSync := (isLinp | isLwgt) & zeroSize
}
|
||||
|
||||
/** ComputeDecode.
  *
  * Decodes dependency flags, instruction type and sync for the Compute module.
  * An LACC/LUOP instruction whose xsize field is zero is reported as a sync
  * rather than as a load.
  */
class ComputeDecode extends Module {
  val io = IO(new Bundle {
    val inst = Input(UInt(INST_BITS.W))
    val push_next = Output(Bool())
    val push_prev = Output(Bool())
    val pop_next = Output(Bool())
    val pop_prev = Output(Bool())
    val isLoadAcc = Output(Bool())
    val isLoadUop = Output(Bool())
    val isSync = Output(Bool())
    val isAlu = Output(Bool())
    val isGemm = Output(Bool())
    val isFinish = Output(Bool())
  })
  // The dependency flags and xsize are read through the memory-instruction
  // view for every instruction kind.
  val dec = io.inst.asTypeOf(new MemDecode)

  // Building blocks shared by the type outputs below.
  val zeroSize = dec.xsize === 0.U
  val isLacc = io.inst === LACC
  val isLuop = io.inst === LUOP

  // Dependency flags are forwarded unchanged from the instruction.
  io.push_next := dec.push_next
  io.push_prev := dec.push_prev
  io.pop_next := dec.pop_next
  io.pop_prev := dec.pop_prev

  io.isLoadAcc := isLacc & !zeroSize
  io.isLoadUop := isLuop & !zeroSize
  io.isSync := (isLacc | isLuop) & zeroSize

  // Any of the four vector ALU instructions.
  io.isAlu := Seq(VMIN, VMAX, VADD, VSHX).map(io.inst === _).reduce(_ | _)
  io.isGemm := io.inst === GEMM
  io.isFinish := io.inst === FNSH
}
|
||||
|
||||
/** StoreDecode.
  *
  * Decodes dependency flags, instruction type and sync for the Store module.
  * An SOUT instruction whose xsize field is zero is reported as a sync rather
  * than as a store.
  */
class StoreDecode extends Module {
  val io = IO(new Bundle {
    val inst = Input(UInt(INST_BITS.W))
    val push_prev = Output(Bool())
    val pop_prev = Output(Bool())
    val isStore = Output(Bool())
    val isSync = Output(Bool())
  })
  val dec = io.inst.asTypeOf(new MemDecode)

  // Building blocks shared by the two type outputs below.
  val zeroSize = dec.xsize === 0.U
  val isSout = io.inst === SOUT

  // Dependency flags are forwarded unchanged from the instruction.
  io.push_prev := dec.push_prev
  io.pop_prev := dec.pop_prev

  io.isStore := isSout & !zeroSize
  io.isSync := isSout & zeroSize
}
|
|
@ -0,0 +1,197 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
import vta.shell._
|
||||
|
||||
/** Fetch.
|
||||
*
|
||||
* The fetch unit reads instructions (tasks) from memory (i.e. DRAM), using the
|
||||
* VTA Memory Engine (VME), and push them into an instruction queue called
|
||||
* inst_q. Once the instruction queue is full, instructions are dispatched to
|
||||
* the Load, Compute and Store module queues based on the instruction opcode.
|
||||
* After draining the queue, the fetch unit checks if there are more instructions
|
||||
* via the ins_count register which is written by the host.
|
||||
*
|
||||
* Additionally, instructions are read in two chunks (see sReadLSB and sReadMSB)
|
||||
* because we are using a DRAM payload of 8-bytes or half of a VTA instruction.
|
||||
* This should be configurable for larger payloads, i.e. 64-bytes, which can load
|
||||
* more than one instruction at a time. Finally, the instruction queue is
|
||||
* sized (entries_q), depending on the maximum burst allowed in the memory.
|
||||
*/
|
||||
class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
|
||||
val vp = p(ShellKey).vcrParams
|
||||
val mp = p(ShellKey).memParams
|
||||
val io = IO(new Bundle {
|
||||
val launch = Input(Bool())
|
||||
val ins_baddr = Input(UInt(mp.addrBits.W))
|
||||
val ins_count = Input(UInt(vp.regBits.W))
|
||||
val vme_rd = new VMEReadMaster
|
||||
val inst = new Bundle {
|
||||
val ld = Decoupled(UInt(INST_BITS.W))
|
||||
val co = Decoupled(UInt(INST_BITS.W))
|
||||
val st = Decoupled(UInt(INST_BITS.W))
|
||||
}
|
||||
})
|
||||
val entries_q = 1 << (mp.lenBits - 1) // one-instr-every-two-vme-word
|
||||
val inst_q = Module(new Queue(UInt(INST_BITS.W), entries_q))
|
||||
val dec = Module(new FetchDecode)
|
||||
|
||||
val s1_launch = RegNext(io.launch)
|
||||
val pulse = io.launch & ~s1_launch
|
||||
|
||||
val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr))
|
||||
val rlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
|
||||
val ilen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
|
||||
|
||||
val xrem = Reg(chiselTypeOf(io.ins_count))
|
||||
val xsize = (io.ins_count << 1.U) - 1.U
|
||||
val xmax = (1 << mp.lenBits).U
|
||||
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
|
||||
|
||||
val sIdle :: sReadCmd :: sReadLSB :: sReadMSB :: sDrain :: Nil = Enum(5)
|
||||
val state = RegInit(sIdle)
|
||||
|
||||
// control
|
||||
switch (state) {
|
||||
is (sIdle) {
|
||||
when (pulse) {
|
||||
state := sReadCmd
|
||||
when (xsize < xmax) {
|
||||
rlen := xsize
|
||||
ilen := xsize >> 1.U
|
||||
xrem := 0.U
|
||||
} .otherwise {
|
||||
rlen := xmax - 1.U
|
||||
ilen := (xmax >> 1.U) - 1.U
|
||||
xrem := xsize - xmax
|
||||
}
|
||||
}
|
||||
}
|
||||
is (sReadCmd) {
|
||||
when (io.vme_rd.cmd.ready) {
|
||||
state := sReadLSB
|
||||
}
|
||||
}
|
||||
is (sReadLSB) {
|
||||
when (io.vme_rd.data.valid) {
|
||||
state := sReadMSB
|
||||
}
|
||||
}
|
||||
is (sReadMSB) {
|
||||
when (io.vme_rd.data.valid) {
|
||||
when (inst_q.io.count === ilen) {
|
||||
state := sDrain
|
||||
} .otherwise {
|
||||
state := sReadLSB
|
||||
}
|
||||
}
|
||||
}
|
||||
is (sDrain) {
|
||||
when (inst_q.io.count === 0.U) {
|
||||
when (xrem === 0.U) {
|
||||
state := sIdle
|
||||
} .elsewhen (xrem < xmax) {
|
||||
state := sReadCmd
|
||||
rlen := xrem
|
||||
ilen := xrem >> 1.U
|
||||
xrem := 0.U
|
||||
} .otherwise {
|
||||
state := sReadCmd
|
||||
rlen := xmax - 1.U
|
||||
ilen := (xmax >> 1.U) - 1.U
|
||||
xrem := xrem - xmax
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// read instructions from dram
|
||||
when (state === sIdle) {
|
||||
raddr := io.ins_baddr
|
||||
} .elsewhen (state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) {
|
||||
raddr := raddr + xmax_bytes
|
||||
}
|
||||
|
||||
io.vme_rd.cmd.valid := state === sReadCmd
|
||||
io.vme_rd.cmd.bits.addr := raddr
|
||||
io.vme_rd.cmd.bits.len := rlen
|
||||
|
||||
io.vme_rd.data.ready := inst_q.io.enq.ready
|
||||
|
||||
val lsb = Reg(chiselTypeOf(io.vme_rd.data.bits))
|
||||
val msb = io.vme_rd.data.bits
|
||||
val inst = Cat(msb, lsb)
|
||||
|
||||
when (state === sReadLSB) { lsb := io.vme_rd.data.bits }
|
||||
|
||||
inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB
|
||||
inst_q.io.enq.bits := inst
|
||||
|
||||
// decode
|
||||
dec.io.inst := inst_q.io.deq.bits
|
||||
|
||||
// instruction queues
|
||||
io.inst.ld.valid := dec.io.isLoad & inst_q.io.deq.valid & state === sDrain
|
||||
io.inst.co.valid := dec.io.isCompute & inst_q.io.deq.valid & state === sDrain
|
||||
io.inst.st.valid := dec.io.isStore & inst_q.io.deq.valid & state === sDrain
|
||||
|
||||
io.inst.ld.bits := inst_q.io.deq.bits
|
||||
io.inst.co.bits := inst_q.io.deq.bits
|
||||
io.inst.st.bits := inst_q.io.deq.bits
|
||||
|
||||
// check if selected queue is ready
|
||||
val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt
|
||||
val deq_ready =
|
||||
MuxLookup(deq_sel,
|
||||
false.B, // default
|
||||
Array(
|
||||
"h_01".U -> io.inst.ld.ready,
|
||||
"h_02".U -> io.inst.st.ready,
|
||||
"h_04".U -> io.inst.co.ready
|
||||
)
|
||||
)
|
||||
|
||||
// dequeue instruction
|
||||
inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain
|
||||
|
||||
|
||||
// debug
|
||||
if (debug) {
|
||||
when (state === sIdle && pulse) {
|
||||
printf("[Fetch] Launch\n")
|
||||
}
|
||||
// instruction
|
||||
when (inst_q.io.deq.fire()) {
|
||||
when (dec.io.isLoad) {
|
||||
printf("[Fetch] [instruction decode] [L] %x\n", inst_q.io.deq.bits)
|
||||
}
|
||||
when (dec.io.isCompute) {
|
||||
printf("[Fetch] [instruction decode] [C] %x\n", inst_q.io.deq.bits)
|
||||
}
|
||||
when (dec.io.isStore) {
|
||||
printf("[Fetch] [instruction decode] [S] %x\n", inst_q.io.deq.bits)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
|
||||
/** ISAConstants.
|
||||
*
|
||||
* These constants are used for decoding (parsing) fields on instructions.
|
||||
*/
|
||||
trait ISAConstants
|
||||
{
|
||||
val INST_BITS = 128
|
||||
|
||||
val OP_BITS = 3
|
||||
|
||||
val M_DEP_BITS = 4
|
||||
val M_ID_BITS = 2
|
||||
val M_SRAM_OFFSET_BITS = 16
|
||||
val M_DRAM_OFFSET_BITS = 32
|
||||
val M_SIZE_BITS = 16
|
||||
val M_STRIDE_BITS = 16
|
||||
val M_PAD_BITS = 4
|
||||
|
||||
val C_UOP_BGN_BITS = 13
|
||||
val C_UOP_END_BITS = 14
|
||||
val C_ITER_BITS = 14
|
||||
val C_AIDX_BITS = 11
|
||||
val C_IIDX_BITS = 11
|
||||
val C_WIDX_BITS = 10
|
||||
val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction
|
||||
val C_ALU_OP_BITS = 3
|
||||
val C_ALU_IMM_BITS = 16
|
||||
|
||||
val Y = true.B
|
||||
val N = false.B
|
||||
|
||||
val OP_L = 0.asUInt(OP_BITS.W)
|
||||
val OP_S = 1.asUInt(OP_BITS.W)
|
||||
val OP_G = 2.asUInt(OP_BITS.W)
|
||||
val OP_F = 3.asUInt(OP_BITS.W)
|
||||
val OP_A = 4.asUInt(OP_BITS.W)
|
||||
val OP_X = 5.asUInt(OP_BITS.W)
|
||||
|
||||
val ALU_OP_NUM = 5
|
||||
val ALU_OP = Enum(ALU_OP_NUM)
|
||||
|
||||
val M_ID_U = 0.asUInt(M_ID_BITS.W)
|
||||
val M_ID_W = 1.asUInt(M_ID_BITS.W)
|
||||
val M_ID_I = 2.asUInt(M_ID_BITS.W)
|
||||
val M_ID_A = 3.asUInt(M_ID_BITS.W)
|
||||
}
|
||||
|
||||
/** ISA.
|
||||
*
|
||||
* This is the VTA ISA. Here we specify the cares and don't-cares that make
|
||||
* decoding easier. Since instructions are quite long (128 bits), we could generate
|
||||
* these based on ISAConstants.
|
||||
*
|
||||
* FIXME: VSHX should be replaced by VSHR and VSHL once we modify the compiler
|
||||
* TODO: Add VXOR to clear accumulator
|
||||
*/
|
||||
object ISA {
|
||||
def LUOP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000")
|
||||
def LWGT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000")
|
||||
def LINP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000")
|
||||
def LACC = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000")
|
||||
def SOUT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001")
|
||||
def GEMM = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010")
|
||||
def VMIN = BitPat("b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
|
||||
def VMAX = BitPat("b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
|
||||
def VADD = BitPat("b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
|
||||
def VSHX = BitPat("b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
|
||||
def FNSH = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011")
|
||||
}
|
|
@ -0,0 +1,131 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
import vta.shell._
|
||||
|
||||
/** Load.
|
||||
*
|
||||
* Load inputs and weights from memory (DRAM) into scratchpads (SRAMs).
|
||||
* This module instantiates the TensorLoad unit, which is in charge of
|
||||
* loading 1D and 2D tensors to scratchpads, so it can be used by
|
||||
* other modules such as Compute.
|
||||
*/
|
||||
class Load(debug: Boolean = false)(implicit p: Parameters) extends Module {
|
||||
val mp = p(ShellKey).memParams
|
||||
val io = IO(new Bundle {
|
||||
val i_post = Input(Bool())
|
||||
val o_post = Output(Bool())
|
||||
val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
|
||||
val inp_baddr = Input(UInt(mp.addrBits.W))
|
||||
val wgt_baddr = Input(UInt(mp.addrBits.W))
|
||||
val vme_rd = Vec(2, new VMEReadMaster)
|
||||
val inp = new TensorClient(tensorType = "inp")
|
||||
val wgt = new TensorClient(tensorType = "wgt")
|
||||
})
|
||||
val sIdle :: sSync :: sExe :: Nil = Enum(3)
|
||||
val state = RegInit(sIdle)
|
||||
|
||||
val s = Module(new Semaphore(counterBits = 8, counterInitValue = 0))
|
||||
val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
|
||||
|
||||
val dec = Module(new LoadDecode)
|
||||
dec.io.inst := inst_q.io.deq.bits
|
||||
|
||||
val tensorType = Seq("inp", "wgt")
|
||||
val tensorDec = Seq(dec.io.isInput, dec.io.isWeight)
|
||||
val tensorLoad = Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i))))
|
||||
|
||||
val start = inst_q.io.deq.valid & Mux(dec.io.pop_next, s.io.sready, true.B)
|
||||
val done = Mux(dec.io.isInput, tensorLoad(0).io.done, tensorLoad(1).io.done)
|
||||
|
||||
// control
|
||||
switch (state) {
|
||||
is (sIdle) {
|
||||
when (start) {
|
||||
when (dec.io.isSync) {
|
||||
state := sSync
|
||||
} .elsewhen (dec.io.isInput || dec.io.isWeight) {
|
||||
state := sExe
|
||||
}
|
||||
}
|
||||
}
|
||||
is (sSync) {
|
||||
state := sIdle
|
||||
}
|
||||
is (sExe) {
|
||||
when (done) {
|
||||
state := sIdle
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// instructions
|
||||
inst_q.io.enq <> io.inst
|
||||
inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
|
||||
|
||||
// load tensor
|
||||
// [0] input (inp)
|
||||
// [1] weight (wgt)
|
||||
val ptr = Seq(io.inp_baddr, io.wgt_baddr)
|
||||
val tsor = Seq(io.inp, io.wgt)
|
||||
for (i <- 0 until 2) {
|
||||
tensorLoad(i).io.start := state === sIdle & start & tensorDec(i)
|
||||
tensorLoad(i).io.inst := inst_q.io.deq.bits
|
||||
tensorLoad(i).io.baddr := ptr(i)
|
||||
tensorLoad(i).io.tensor <> tsor(i)
|
||||
io.vme_rd(i) <> tensorLoad(i).io.vme_rd
|
||||
}
|
||||
|
||||
// semaphore
|
||||
s.io.spost := io.i_post
|
||||
s.io.swait := dec.io.pop_next & (state === sIdle & start)
|
||||
io.o_post := dec.io.push_next & ((state === sExe & done) | (state === sSync))
|
||||
|
||||
// debug
|
||||
if (debug) {
|
||||
// start
|
||||
when (state === sIdle && start) {
|
||||
when (dec.io.isSync) {
|
||||
printf("[Load] start sync\n")
|
||||
} .elsewhen (dec.io.isInput) {
|
||||
printf("[Load] start input\n")
|
||||
} .elsewhen (dec.io.isWeight) {
|
||||
printf("[Load] start weight\n")
|
||||
}
|
||||
}
|
||||
// done
|
||||
when (state === sSync) {
|
||||
printf("[Load] done sync\n")
|
||||
}
|
||||
when (state === sExe) {
|
||||
when (done) {
|
||||
when (dec.io.isInput) {
|
||||
printf("[Load] done input\n")
|
||||
} .elsewhen (dec.io.isWeight) {
|
||||
printf("[Load] done weight\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,214 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
import vta.shell._
|
||||
|
||||
/** UopMaster.
|
||||
*
|
||||
* Uop interface used by a master module, i.e. TensorAlu or TensorGemm,
|
||||
* to request a micro-op (uop) from the uop-scratchpad. The index (idx) is
|
||||
* used as an address to find the uop in the uop-scratchpad.
|
||||
*/
|
||||
class UopMaster(implicit p: Parameters) extends Bundle {
|
||||
val addrBits = log2Ceil(p(CoreKey).uopMemDepth)
|
||||
val idx = ValidIO(UInt(addrBits.W))
|
||||
val data = Flipped(ValidIO(new UopDecode))
|
||||
override def cloneType = new UopMaster().asInstanceOf[this.type]
|
||||
}
|
||||
|
||||
/** UopClient.
|
||||
*
|
||||
* Uop interface used by a client module, i.e. LoadUop, to receive
|
||||
* a request from a master module, i.e. TensorAlu or TensorGemm.
|
||||
* The index (idx) is used as an address to find the uop in the uop-scratchpad.
|
||||
*/
|
||||
class UopClient(implicit p: Parameters) extends Bundle {
|
||||
val addrBits = log2Ceil(p(CoreKey).uopMemDepth)
|
||||
val idx = Flipped(ValidIO(UInt(addrBits.W)))
|
||||
val data = ValidIO(new UopDecode)
|
||||
override def cloneType = new UopClient().asInstanceOf[this.type]
|
||||
}
|
||||
|
||||
/** LoadUop.
|
||||
*
|
||||
* Load micro-ops (uops) from memory, i.e. DRAM, and store them in the
|
||||
* uop-scratchpad. Currently, micro-ops are 32-bit wide and loaded in
|
||||
* groups of 2 given the fact that the DRAM payload is 8 bytes. This module
|
||||
* should be modified later on to support different DRAM sizes efficiently.
|
||||
*/
|
||||
class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
|
||||
val mp = p(ShellKey).memParams
|
||||
val io = IO(new Bundle {
|
||||
val start = Input(Bool())
|
||||
val done = Output(Bool())
|
||||
val inst = Input(UInt(INST_BITS.W))
|
||||
val baddr = Input(UInt(mp.addrBits.W))
|
||||
val vme_rd = new VMEReadMaster
|
||||
val uop = new UopClient
|
||||
})
|
||||
val numUop = 2 // store two uops per sram word
|
||||
val uopBits = p(CoreKey).uopBits
|
||||
val uopDepth = p(CoreKey).uopMemDepth / numUop
|
||||
|
||||
val dec = io.inst.asTypeOf(new MemDecode)
|
||||
val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr))
|
||||
val xcnt = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
|
||||
val xlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
|
||||
val xrem = Reg(chiselTypeOf(dec.xsize))
|
||||
val xsize = dec.xsize(0) + (dec.xsize >> log2Ceil(numUop)) - 1.U
|
||||
val xmax = (1 << mp.lenBits).U
|
||||
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
|
||||
|
||||
val offsetIsEven = (dec.sram_offset % 2.U) === 0.U
|
||||
val sizeIsEven = (dec.xsize % 2.U) === 0.U
|
||||
|
||||
val sIdle :: sReadCmd :: sReadData :: Nil = Enum(3)
|
||||
val state = RegInit(sIdle)
|
||||
|
||||
// control
|
||||
switch (state) {
|
||||
is (sIdle) {
|
||||
when (io.start) {
|
||||
state := sReadCmd
|
||||
when (xsize < xmax) {
|
||||
xlen := xsize
|
||||
xrem := 0.U
|
||||
} .otherwise {
|
||||
xlen := xmax - 1.U
|
||||
xrem := xsize - xmax
|
||||
}
|
||||
}
|
||||
}
|
||||
is (sReadCmd) {
|
||||
when (io.vme_rd.cmd.ready) {
|
||||
state := sReadData
|
||||
}
|
||||
}
|
||||
is (sReadData) {
|
||||
when (io.vme_rd.data.valid) {
|
||||
when(xcnt === xlen) {
|
||||
when (xrem === 0.U) {
|
||||
state := sIdle
|
||||
} .elsewhen (xrem < xmax) {
|
||||
state := sReadCmd
|
||||
xlen := xrem
|
||||
xrem := 0.U
|
||||
} .otherwise {
|
||||
state := sReadCmd
|
||||
xlen := xmax - 1.U
|
||||
xrem := xrem - xmax
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// read-from-dram
|
||||
when (state === sIdle) {
|
||||
when (offsetIsEven) {
|
||||
raddr := io.baddr + dec.dram_offset
|
||||
} .otherwise {
|
||||
raddr := io.baddr + dec.dram_offset - 4.U
|
||||
}
|
||||
} .elsewhen (state === sReadData && xcnt === xlen && xrem =/= 0.U) {
|
||||
raddr := raddr + xmax_bytes
|
||||
}
|
||||
|
||||
io.vme_rd.cmd.valid := state === sReadCmd
|
||||
io.vme_rd.cmd.bits.addr := raddr
|
||||
io.vme_rd.cmd.bits.len := xlen
|
||||
|
||||
io.vme_rd.data.ready := state === sReadData
|
||||
|
||||
when (state =/= sReadData) {
|
||||
xcnt := 0.U
|
||||
} .elsewhen (io.vme_rd.data.fire()) {
|
||||
xcnt := xcnt + 1.U
|
||||
}
|
||||
|
||||
val waddr = Reg(UInt(log2Ceil(uopDepth).W))
|
||||
when (state === sIdle) {
|
||||
waddr := dec.sram_offset >> log2Ceil(numUop)
|
||||
} .elsewhen (io.vme_rd.data.fire()) {
|
||||
waddr := waddr + 1.U
|
||||
}
|
||||
|
||||
val wdata = Wire(Vec(numUop, UInt(uopBits.W)))
|
||||
val mem = SyncReadMem(uopDepth, chiselTypeOf(wdata))
|
||||
val wmask = Reg(Vec(numUop, Bool()))
|
||||
|
||||
when (offsetIsEven) {
|
||||
when (sizeIsEven) {
|
||||
wmask := "b_11".U.asTypeOf(wmask)
|
||||
} .elsewhen (io.vme_rd.cmd.fire()) {
|
||||
when (dec.xsize === 1.U) {
|
||||
wmask := "b_01".U.asTypeOf(wmask)
|
||||
} .otherwise {
|
||||
wmask := "b_11".U.asTypeOf(wmask)
|
||||
}
|
||||
} .elsewhen (io.vme_rd.data.fire()) {
|
||||
when (xcnt === xlen - 1.U) {
|
||||
wmask := "b_01".U.asTypeOf(wmask)
|
||||
} .otherwise {
|
||||
wmask := "b_11".U.asTypeOf(wmask)
|
||||
}
|
||||
}
|
||||
} .otherwise {
|
||||
when (io.vme_rd.cmd.fire()) {
|
||||
wmask := "b_10".U.asTypeOf(wmask)
|
||||
} .elsewhen (io.vme_rd.data.fire()) {
|
||||
when (sizeIsEven && xcnt === xlen - 1.U) {
|
||||
wmask := "b_01".U.asTypeOf(wmask)
|
||||
} .otherwise {
|
||||
wmask := "b_11".U.asTypeOf(wmask)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wdata := io.vme_rd.data.bits.asTypeOf(wdata)
|
||||
when (io.vme_rd.data.fire()) {
|
||||
mem.write(waddr, wdata, wmask)
|
||||
}
|
||||
|
||||
// read-from-sram
|
||||
io.uop.data.valid := RegNext(io.uop.idx.valid)
|
||||
|
||||
val sIdx = io.uop.idx.bits % numUop.U
|
||||
val rIdx = io.uop.idx.bits >> log2Ceil(numUop)
|
||||
val memRead = mem.read(rIdx, io.uop.idx.valid)
|
||||
val sWord = memRead.asUInt.asTypeOf(wdata)
|
||||
val sUop = sWord(sIdx).asTypeOf(io.uop.data.bits)
|
||||
|
||||
io.uop.data.bits <> sUop
|
||||
|
||||
// done
|
||||
io.done := state === sReadData & io.vme_rd.data.valid & xcnt === xlen & xrem === 0.U
|
||||
|
||||
// debug
|
||||
if (debug) {
|
||||
when (io.vme_rd.cmd.fire()) {
|
||||
printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
|
||||
/** Semaphore.
|
||||
*
|
||||
* This semaphore is used instead of push/pop fifo, used in the initial
|
||||
* version of VTA. This semaphore is incremented (spost) or decremented (swait)
|
||||
* depending on the push and pop fields on instructions to prevent RAW and WAR
|
||||
* hazards.
|
||||
*/
|
||||
class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) extends Module {
|
||||
val io = IO(new Bundle {
|
||||
val spost = Input(Bool())
|
||||
val swait = Input(Bool())
|
||||
val sready = Output(Bool())
|
||||
})
|
||||
val cnt = RegInit(counterInitValue.U(counterBits.W))
|
||||
when (io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) { cnt := cnt + 1.U }
|
||||
when (!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U }
|
||||
io.sready := cnt =/= 0.U
|
||||
}
|
|
@ -0,0 +1,114 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
import vta.shell._
|
||||
|
||||
/** Store.
|
||||
*
|
||||
* Store results back to memory (DRAM) from scratchpads (SRAMs).
|
||||
* This module instantiates the TensorStore unit, which is in charge
|
||||
* of storing 1D and 2D tensors to main memory.
|
||||
*/
|
||||
class Store(debug: Boolean = false)(implicit p: Parameters) extends Module {
|
||||
val mp = p(ShellKey).memParams
|
||||
val io = IO(new Bundle {
|
||||
val i_post = Input(Bool())
|
||||
val o_post = Output(Bool())
|
||||
val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
|
||||
val out_baddr = Input(UInt(mp.addrBits.W))
|
||||
val vme_wr = new VMEWriteMaster
|
||||
val out = new TensorClient(tensorType = "out")
|
||||
})
|
||||
val sIdle :: sSync :: sExe :: Nil = Enum(3)
|
||||
val state = RegInit(sIdle)
|
||||
|
||||
val s = Module(new Semaphore(counterBits = 8, counterInitValue = 0))
|
||||
val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
|
||||
|
||||
val dec = Module(new StoreDecode)
|
||||
dec.io.inst := inst_q.io.deq.bits
|
||||
|
||||
val tensorStore = Module(new TensorStore(tensorType = "out"))
|
||||
|
||||
val start = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s.io.sready, true.B)
|
||||
val done = tensorStore.io.done
|
||||
|
||||
// control
|
||||
switch (state) {
|
||||
is (sIdle) {
|
||||
when (start) {
|
||||
when (dec.io.isSync) {
|
||||
state := sSync
|
||||
} .elsewhen (dec.io.isStore) {
|
||||
state := sExe
|
||||
}
|
||||
}
|
||||
}
|
||||
is (sSync) {
|
||||
state := sIdle
|
||||
}
|
||||
is (sExe) {
|
||||
when (done) {
|
||||
state := sIdle
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// instructions
|
||||
inst_q.io.enq <> io.inst
|
||||
inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
|
||||
|
||||
// store
|
||||
tensorStore.io.start := state === sIdle & start & dec.io.isStore
|
||||
tensorStore.io.inst := inst_q.io.deq.bits
|
||||
tensorStore.io.baddr := io.out_baddr
|
||||
io.vme_wr <> tensorStore.io.vme_wr
|
||||
tensorStore.io.tensor <> io.out
|
||||
|
||||
// semaphore
|
||||
s.io.spost := io.i_post
|
||||
s.io.swait := dec.io.pop_prev & (state === sIdle & start)
|
||||
io.o_post := dec.io.push_prev & ((state === sExe & done) | (state === sSync))
|
||||
|
||||
// debug
|
||||
if (debug) {
|
||||
// start
|
||||
when (state === sIdle && start) {
|
||||
when (dec.io.isSync) {
|
||||
printf("[Store] start sync\n")
|
||||
} .elsewhen (dec.io.isStore) {
|
||||
printf("[Store] start\n")
|
||||
}
|
||||
}
|
||||
// done
|
||||
when (state === sSync) {
|
||||
printf("[Store] done sync\n")
|
||||
}
|
||||
when (state === sExe) {
|
||||
when (done) {
|
||||
printf("[Store] done\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,295 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
|
||||
/** ALU datapath */
|
||||
class Alu(implicit p: Parameters) extends Module {
|
||||
val aluBits = p(CoreKey).accBits
|
||||
val io = IO(new Bundle {
|
||||
val opcode = Input(UInt(C_ALU_OP_BITS.W))
|
||||
val a = Input(SInt(aluBits.W))
|
||||
val b = Input(SInt(aluBits.W))
|
||||
val y = Output(SInt(aluBits.W))
|
||||
})
|
||||
|
||||
// FIXME: the following three will change once we support properly SHR and SHL
|
||||
val ub = io.b.asUInt
|
||||
val width = log2Ceil(aluBits)
|
||||
val m = ~ub(width - 1, 0) + 1.U
|
||||
|
||||
val n = ub(width - 1, 0)
|
||||
val fop = Seq(Mux(io.a < io.b, io.a, io.b),
|
||||
Mux(io.a < io.b, io.b, io.a),
|
||||
io.a + io.b,
|
||||
io.a >> n,
|
||||
io.a << m)
|
||||
|
||||
val opmux = Seq.tabulate(ALU_OP_NUM)(i => ALU_OP(i) -> fop(i))
|
||||
io.y := MuxLookup(io.opcode, io.a, opmux)
|
||||
}
|
||||
|
||||
/** Pipelined ALU */
|
||||
class AluReg(implicit p: Parameters) extends Module {
|
||||
val io = IO(new Bundle {
|
||||
val opcode = Input(UInt(C_ALU_OP_BITS.W))
|
||||
val a = Flipped(ValidIO(UInt(p(CoreKey).accBits.W)))
|
||||
val b = Flipped(ValidIO(UInt(p(CoreKey).accBits.W)))
|
||||
val y = ValidIO(UInt(p(CoreKey).accBits.W))
|
||||
})
|
||||
val alu = Module(new Alu)
|
||||
val rA = RegEnable(io.a.bits, io.a.valid)
|
||||
val rB = RegEnable(io.b.bits, io.b.valid)
|
||||
val valid = RegNext(io.b.valid)
|
||||
|
||||
alu.io.opcode := io.opcode
|
||||
|
||||
// register input
|
||||
alu.io.a := rA.asSInt
|
||||
alu.io.b := rB.asSInt
|
||||
|
||||
// output
|
||||
io.y.valid := valid
|
||||
io.y.bits := alu.io.y.asUInt
|
||||
}
|
||||
|
||||
/** Vector of pipeline ALUs */
|
||||
class AluVector(implicit p: Parameters) extends Module {
|
||||
val io = IO(new Bundle {
|
||||
val opcode = Input(UInt(C_ALU_OP_BITS.W))
|
||||
val acc_a = new TensorMasterData(tensorType = "acc")
|
||||
val acc_b = new TensorMasterData(tensorType = "acc")
|
||||
val acc_y = new TensorClientData(tensorType = "acc")
|
||||
val out = new TensorClientData(tensorType = "out")
|
||||
})
|
||||
val blockOut = p(CoreKey).blockOut
|
||||
val f = Seq.fill(blockOut)(Module(new AluReg))
|
||||
val valid = Wire(Vec(blockOut, Bool()))
|
||||
for (i <- 0 until blockOut) {
|
||||
f(i).io.opcode := io.opcode
|
||||
f(i).io.a.valid := io.acc_a.data.valid
|
||||
f(i).io.a.bits := io.acc_a.data.bits(0)(i)
|
||||
f(i).io.b.valid := io.acc_b.data.valid
|
||||
f(i).io.b.bits := io.acc_b.data.bits(0)(i)
|
||||
valid(i) := f(i).io.y.valid
|
||||
io.acc_y.data.bits(0)(i) := f(i).io.y.bits
|
||||
io.out.data.bits(0)(i) := f(i).io.y.bits
|
||||
}
|
||||
io.acc_y.data.valid := valid.asUInt.andR
|
||||
io.out.data.valid := valid.asUInt.andR
|
||||
}
|
||||
|
||||
/** TensorAlu.
|
||||
*
|
||||
* This unit instantiates the ALU vector unit (AluVector) and goes over the
|
||||
* micro-ops (uops) which are used to read the source operands (vectors)
|
||||
* from the acc-scratchpad and then they are written back to the same
|
||||
* acc-scratchpad.
|
||||
*/
|
||||
class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
|
||||
val io = IO(new Bundle {
|
||||
val start = Input(Bool())
|
||||
val done = Output(Bool())
|
||||
val inst = Input(UInt(INST_BITS.W))
|
||||
val uop = new UopMaster
|
||||
val acc = new TensorMaster(tensorType = "acc")
|
||||
val out = new TensorMaster(tensorType = "out")
|
||||
})
|
||||
val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil = Enum(6)
|
||||
val state = RegInit(sIdle)
|
||||
val alu = Module(new AluVector)
|
||||
val dec = io.inst.asTypeOf(new AluDecode)
|
||||
val uop_idx = Reg(chiselTypeOf(dec.uop_end))
|
||||
val uop_end = dec.uop_end
|
||||
val uop_dst = Reg(chiselTypeOf(dec.uop_end))
|
||||
val uop_src = Reg(chiselTypeOf(dec.uop_end))
|
||||
val cnt_o = Reg(chiselTypeOf(dec.lp_0))
|
||||
val dst_o = Reg(chiselTypeOf(dec.uop_end))
|
||||
val src_o = Reg(chiselTypeOf(dec.uop_end))
|
||||
val cnt_i = Reg(chiselTypeOf(dec.lp_1))
|
||||
val dst_i = Reg(chiselTypeOf(dec.uop_end))
|
||||
val src_i = Reg(chiselTypeOf(dec.uop_end))
|
||||
val done =
|
||||
state === sExe &
|
||||
alu.io.out.data.valid &
|
||||
(cnt_o === dec.lp_0 - 1.U) &
|
||||
(cnt_i === dec.lp_1 - 1.U) &
|
||||
(uop_idx === uop_end - 1.U)
|
||||
|
||||
switch (state) {
|
||||
is (sIdle) {
|
||||
when (io.start) {
|
||||
state := sReadUop
|
||||
}
|
||||
}
|
||||
is (sReadUop) {
|
||||
state := sComputeIdx
|
||||
}
|
||||
is (sComputeIdx) {
|
||||
state := sReadTensorA
|
||||
}
|
||||
is (sReadTensorA) {
|
||||
state := sReadTensorB
|
||||
}
|
||||
is (sReadTensorB) {
|
||||
state := sExe
|
||||
}
|
||||
is (sExe) {
|
||||
when (alu.io.out.data.valid) {
|
||||
when ((cnt_o === dec.lp_0 - 1.U) &&
|
||||
(cnt_i === dec.lp_1 - 1.U) &&
|
||||
(uop_idx === uop_end - 1.U)) {
|
||||
state := sIdle
|
||||
} .otherwise {
|
||||
state := sReadUop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
when (state === sIdle ||
|
||||
(state === sExe &&
|
||||
alu.io.out.data.valid &&
|
||||
uop_idx === uop_end - 1.U)) {
|
||||
uop_idx := dec.uop_begin
|
||||
} .elsewhen (state === sExe && alu.io.out.data.valid) {
|
||||
uop_idx := uop_idx + 1.U
|
||||
}
|
||||
|
||||
when (state === sIdle) {
|
||||
cnt_o := 0.U
|
||||
dst_o := 0.U
|
||||
src_o := 0.U
|
||||
} .elsewhen (state === sExe &&
|
||||
alu.io.out.data.valid &&
|
||||
uop_idx === uop_end - 1.U &&
|
||||
cnt_i === dec.lp_1 - 1.U) {
|
||||
cnt_o := cnt_o + 1.U
|
||||
dst_o := dst_o + dec.dst_0
|
||||
src_o := src_o + dec.src_0
|
||||
}
|
||||
|
||||
when (state === sIdle) {
|
||||
cnt_i := 0.U
|
||||
dst_i := 0.U
|
||||
src_i := 0.U
|
||||
} .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
|
||||
cnt_i := 0.U
|
||||
dst_i := dst_o
|
||||
src_i := src_o
|
||||
} .elsewhen (state === sExe &&
|
||||
alu.io.out.data.valid &&
|
||||
uop_idx === uop_end - 1.U) {
|
||||
cnt_i := cnt_i + 1.U
|
||||
dst_i := dst_i + dec.dst_1
|
||||
src_i := src_i + dec.src_1
|
||||
}
|
||||
|
||||
when (state === sComputeIdx && io.uop.data.valid) {
|
||||
uop_dst := io.uop.data.bits.u0 + dst_i
|
||||
uop_src := io.uop.data.bits.u1 + src_i
|
||||
}
|
||||
|
||||
// uop
|
||||
io.uop.idx.valid := state === sReadUop
|
||||
io.uop.idx.bits := uop_idx
|
||||
|
||||
// acc_i
|
||||
io.acc.rd.idx.valid := state === sReadTensorA | (state === sReadTensorB & ~dec.alu_use_imm)
|
||||
io.acc.rd.idx.bits := Mux(state === sReadTensorA, uop_dst, uop_src)
|
||||
|
||||
// imm
|
||||
val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
|
||||
tensorImm.data.valid := state === sReadTensorB
|
||||
tensorImm.data.bits.foreach { b => b.foreach { c => c := dec.alu_imm } }
|
||||
|
||||
// alu
|
||||
val isSHR = dec.alu_op === ALU_OP(3)
|
||||
val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS-1)
|
||||
val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op))
|
||||
alu.io.opcode := fixme_alu_op
|
||||
alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
|
||||
alu.io.acc_a.data.bits <> io.acc.rd.data.bits
|
||||
alu.io.acc_b.data.valid := Mux(dec.alu_use_imm, tensorImm.data.valid, io.acc.rd.data.valid & state === sExe)
|
||||
alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm, tensorImm.data.bits, io.acc.rd.data.bits)
|
||||
|
||||
// acc_o
|
||||
io.acc.wr.valid := alu.io.acc_y.data.valid
|
||||
io.acc.wr.bits.idx := uop_dst
|
||||
io.acc.wr.bits.data <> alu.io.acc_y.data.bits
|
||||
|
||||
// out
|
||||
io.out.wr.valid := alu.io.out.data.valid
|
||||
io.out.wr.bits.idx := uop_dst
|
||||
io.out.wr.bits.data <> alu.io.out.data.bits
|
||||
io.out.tieoffRead() // write-only
|
||||
|
||||
io.done := done
|
||||
|
||||
if (debug) {
|
||||
|
||||
when (state === sReadUop) {
|
||||
printf("[TensorAlu] [uop] idx:%x\n", uop_idx)
|
||||
}
|
||||
|
||||
when (state === sReadTensorA) {
|
||||
printf("[TensorAlu] [uop] dst:%x src:%x\n", uop_dst, uop_src)
|
||||
}
|
||||
|
||||
when (state === sIdle && io.start) {
|
||||
printf(p"[TensorAlu] decode:$dec\n")
|
||||
}
|
||||
|
||||
alu.io.acc_a.data.bits.foreach { tensor =>
|
||||
tensor.zipWithIndex.foreach { case(elem, i) =>
|
||||
when (alu.io.acc_a.data.valid) {
|
||||
printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
alu.io.acc_b.data.bits.foreach { tensor =>
|
||||
tensor.zipWithIndex.foreach { case(elem, i) =>
|
||||
when (alu.io.acc_b.data.valid) {
|
||||
printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
alu.io.acc_y.data.bits.foreach { tensor =>
|
||||
tensor.zipWithIndex.foreach { case(elem, i) =>
|
||||
when (alu.io.acc_y.data.valid) {
|
||||
printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
alu.io.out.data.bits.foreach { tensor =>
|
||||
tensor.zipWithIndex.foreach { case(elem, i) =>
|
||||
when (alu.io.out.data.valid) {
|
||||
printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,364 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import chisel3.experimental._
|
||||
import vta.util.config._
|
||||
import scala.math.pow
|
||||
|
||||
/** Registered multiply-accumulate: `y = a * b + c`.
 *
 * The three operands are captured with `RegNext`, and the product and the sum
 * are then formed combinationally from the registered copies, so `y` reflects
 * the inputs presented on the previous clock edge.
 *
 * @param dataBits width of the signed multiplicands `a` and `b`
 * @param cBits    width of the signed addend `c` and of the product wire;
 *                 must be at least `2 * dataBits`
 * @param outBits  width of the signed result `y`; must be at least `2 * dataBits`
 */
class MAC(dataBits: Int = 8, cBits: Int = 16, outBits: Int = 17) extends Module {
  // A full product of two dataBits-wide operands needs 2 * dataBits bits.
  require (cBits >= 2 * dataBits)
  require (outBits >= 2 * dataBits)

  val io = IO(new Bundle {
    val a = Input(SInt(dataBits.W))
    val b = Input(SInt(dataBits.W))
    val c = Input(SInt(cBits.W))
    val y = Output(SInt(outBits.W))
  })

  // Input registers: the single pipeline stage of this unit.
  val rA = RegNext(io.a)
  val rB = RegNext(io.b)
  val rC = RegNext(io.c)

  // Product, held in a cBits-wide wire.
  val mult = Wire(SInt(cBits.W))
  mult := rA * rB

  // Sum of addend and product, connected to an outBits-wide wire.
  val add = Wire(SInt(outBits.W))
  add := rC + mult

  io.y := add
}
|
||||
|
||||
/** Registered signed adder: `y = a + b`.
 *
 * Both operands are captured with `RegNext` and summed combinationally from
 * the registered copies, so `y` reflects the inputs presented on the previous
 * clock edge.
 *
 * @param dataBits width of the signed operands `a` and `b`
 * @param outBits  width of the signed result `y`; must be at least `dataBits`
 */
class Adder(dataBits: Int = 8, outBits: Int = 17) extends Module {
  require (outBits >= dataBits)

  val io = IO(new Bundle {
    val a = Input(SInt(dataBits.W))
    val b = Input(SInt(dataBits.W))
    val y = Output(SInt(outBits.W))
  })

  // Input registers: the single pipeline stage of this unit.
  val rA = RegNext(io.a)
  val rB = RegNext(io.b)

  // Sum of the registered operands, connected to an outBits-wide wire.
  val add = Wire(SInt(outBits.W))
  add := rA + rB

  io.y := add
}
|
||||
|
||||
/** Pipelined dot product of two signed vectors, built from MAC and Adder units.
 *
 * Structure, for a vector of `size` elements:
 *  - a first row of `size / 2` MACs multiplies element pairs `0 until size / 2`
 *    (with a zero addend);
 *  - a second row of `size / 2` MACs multiplies the remaining pairs, which are
 *    delayed by one register, and accumulates them onto the first row's result;
 *  - a binary tree of Adders reduces the `size / 2` partial sums to `y`.
 *
 * Each MAC/Adder level is given one more bit than the level feeding it.
 *
 * @param dataBits width of each signed input element
 * @param size     number of elements per vector; must be at least 4 and a
 *                 power of 2
 */
class DotProduct(dataBits: Int = 8, size: Int = 16) extends Module {
  // The check below accepts size == 4, so the message says "at least 4".
  val errMsg = "\n\n[VTA] [DotProduct] size must be at least 4 and a power of 2\n\n"
  require(size >= 4 && isPow2(size), errMsg)
  val b = dataBits * 2 // width of one full product
  val outBits = b + log2Ceil(size) + 1
  val io = IO(new Bundle {
    val a = Input(Vec(size, SInt(dataBits.W)))
    val b = Input(Vec(size, SInt(dataBits.W)))
    val y = Output(SInt(outBits.W))
  })
  val p = log2Ceil(size/2) // number of adder-tree levels
  // s(i) is the number of units at level i: size/2, size/4, ..., 1.
  val s = Seq.tabulate(log2Ceil(size))(i => pow(2, p - i).toInt)
  // Upper half of each input vector, delayed by one cycle for the second MAC row.
  val da = Seq.tabulate(s(0))(i => RegNext(io.a(s(0) + i)))
  val db = Seq.tabulate(s(0))(i => RegNext(io.b(s(0) + i)))
  // Two rows of MACs; the second row is one bit wider than the first.
  val m = Seq.tabulate(2)(i =>
    Seq.fill(s(0))(Module(new MAC(dataBits = dataBits, cBits = b + i, outBits = b + i + 1)))
  )
  // Adder tree; level i holds s(i + 1) adders.
  val a = Seq.tabulate(p)(i =>
    Seq.fill(s(i + 1))(Module(new Adder(dataBits = b + i + 2, outBits = b + i + 3)))
  )

  // MAC rows: row 0 multiplies the lower half, row 1 multiplies the delayed
  // upper half and adds row 0's result.
  for (j <- 0 until s(0)) {
    m(0)(j).io.a := io.a(j)
    m(0)(j).io.b := io.b(j)
    m(0)(j).io.c := 0.S
    m(1)(j).io.a := da(j)
    m(1)(j).io.b := db(j)
    m(1)(j).io.c := m(0)(j).io.y
  }

  // Adder tree: level 0 is fed by neighbouring pairs of the second MAC row,
  // every later level by neighbouring pairs of the previous adder level.
  for (i <- 1 until log2Ceil(size); j <- 0 until s(i)) {
    val (lhs, rhs) =
      if (i == 1) (m(1)(2*j).io.y, m(1)(2*j + 1).io.y)
      else (a(i - 2)(2*j).io.y, a(i - 2)(2*j + 1).io.y)
    a(i - 1)(j).io.a := lhs
    a(i - 1)(j).io.b := rhs
  }

  // The root of the tree is the dot product.
  io.y := a(p-1)(0).io.y
}
|
||||
|
||||
/** Matrix-vector multiply-accumulate built from one DotProduct per output element.
 *
 * For every output element `i`, `dot(i)` multiplies row 0 of the input tensor
 * with row `i` of the weight tensor. The matching accumulator element is
 * delayed through a `Pipe` of `log2Ceil(size) + 1` stages and then added to
 * the dot-product result.
 *
 * While `io.reset` is high no accumulator value is enqueued, `acc_o` carries
 * zeros and `acc_o` is flagged valid; `out` is not affected by `io.reset`.
 */
class MatrixVectorCore(implicit p: Parameters) extends Module {
  val accBits = p(CoreKey).accBits
  val size = p(CoreKey).blockOut
  val dataBits = p(CoreKey).inpBits
  val io = IO(new Bundle{
    val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr
    val inp = new TensorMasterData(tensorType = "inp")
    val wgt = new TensorMasterData(tensorType = "wgt")
    val acc_i = new TensorMasterData(tensorType = "acc")
    val acc_o = new TensorClientData(tensorType = "acc")
    val out = new TensorClientData(tensorType = "out")
  })
  val dot = Seq.fill(size)(Module(new DotProduct(dataBits, size)))
  val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
  val add = Seq.fill(size)(Wire(SInt(accBits.W)))
  val vld = Wire(Vec(size, Bool()))

  // An accumulator value enters the delay pipe only when all three operands
  // are valid and no reset is requested.
  private val enqValid =
    io.inp.data.valid & io.wgt.data.valid & io.acc_i.data.valid & ~io.reset

  for (i <- 0 until size) {
    val pipe = acc(i)
    val dp = dot(i)

    pipe.io.enq.valid := enqValid
    pipe.io.enq.bits := io.acc_i.data.bits(0)(i)

    // Input row 0 against weight row i.
    for (j <- 0 until size) {
      dp.io.a(j) := io.inp.data.bits(0)(j).asSInt
      dp.io.b(j) := io.wgt.data.bits(i)(j).asSInt
    }

    // Delayed accumulator plus dot product.
    add(i) := pipe.io.deq.bits.asSInt + dp.io.y
    io.acc_o.data.bits(0)(i) := Mux(io.reset, 0.U, add(i).asUInt)
    io.out.data.bits(0)(i) := add(i).asUInt
    vld(i) := pipe.io.deq.valid
  }

  // Results are valid once every lane's delay pipe presents a valid entry.
  private val allVld = vld.asUInt.andR
  io.acc_o.data.valid := allVld | io.reset
  io.out.data.valid := allVld
}
|
||||
|
||||
/** TensorGemm.
 *
 * This unit instantiates the MatrixVectorCore and iterates over the
 * micro-ops (uops), which are used to read inputs, weights and biases
 * and to write results back to the acc and out scratchpads.
 *
 * The TensorGemm also uses the reset field of the Gemm instruction to
 * clear (zero out) the acc-scratchpad locations selected by the micro-ops.
 *
 * Iteration is a two-level loop (`lp_0` outer, `lp_1` inner) around the uop
 * range `uop_begin until uop_end`; each uop takes the states
 * sReadUop -> sComputeIdx -> sReadTensor -> sExe.
 *
 * @param debug when true, emit simulation printfs for uops and tensor data
 */
class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {
    val start = Input(Bool())
    val done = Output(Bool())
    val inst = Input(UInt(INST_BITS.W))
    val uop = new UopMaster
    val inp = new TensorMaster(tensorType = "inp")
    val wgt = new TensorMaster(tensorType = "wgt")
    val acc = new TensorMaster(tensorType = "acc")
    val out = new TensorMaster(tensorType = "out")
  })
  val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6)
  val state = RegInit(sIdle)
  val mvc = Module(new MatrixVectorCore)
  val dec = io.inst.asTypeOf(new GemmDecode)
  // Current uop and the scratchpad indices computed from it.
  val uop_idx = Reg(chiselTypeOf(dec.uop_end))
  val uop_end = dec.uop_end
  val uop_acc = Reg(chiselTypeOf(dec.uop_end))
  val uop_inp = Reg(chiselTypeOf(dec.uop_end))
  val uop_wgt = Reg(chiselTypeOf(dec.uop_end))
  // Outer-loop counter and index offsets.
  val cnt_o = Reg(chiselTypeOf(dec.lp_0))
  val acc_o = Reg(chiselTypeOf(dec.uop_end))
  val inp_o = Reg(chiselTypeOf(dec.uop_end))
  val wgt_o = Reg(chiselTypeOf(dec.uop_end))
  // Inner-loop counter and index offsets.
  val cnt_i = Reg(chiselTypeOf(dec.lp_1))
  val acc_i = Reg(chiselTypeOf(dec.uop_end))
  val inp_i = Reg(chiselTypeOf(dec.uop_end))
  val wgt_i = Reg(chiselTypeOf(dec.uop_end))
  val pBits = log2Ceil(p(CoreKey).blockOut) + 1
  // Number of operations issued in sExe whose result has not yet been seen.
  val inflight = Reg(UInt(pBits.W))
  // Carries the acc index alongside the computation so the write-back uses
  // the index that was current when the operation was issued.
  val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = pBits))

  // Frequently used conditions.
  private val inIdle = state === sIdle
  private val inExe = state === sExe
  private val lastUop = uop_idx === uop_end - 1.U
  private val lastInner = cnt_i === dec.lp_1 - 1.U
  private val lastOuter = cnt_o === dec.lp_0 - 1.U
  private val drained = inflight === 0.U

  // Finished when nothing is in flight and either the final uop of the final
  // iteration is executing or the unit is waiting for the pipeline to drain.
  val done = drained &
    ((inExe & lastOuter & lastInner & lastUop) |
     state === sWait)

  // control
  switch (state) {
    is (sIdle) {
      when (io.start) {
        state := sReadUop
      }
    }
    is (sReadUop) {
      state := sComputeIdx
    }
    is (sComputeIdx) {
      state := sReadTensor
    }
    is (sReadTensor) {
      state := sExe
    }
    is (sExe) {
      when (lastOuter && lastInner && lastUop) {
        // Last operation issued: leave directly or wait for results.
        state := Mux(drained, sIdle, sWait)
      } .otherwise {
        state := sReadUop
      }
    }
    is (sWait) {
      when (drained) {
        state := sIdle
      }
    }
  }

  // in-flight counter (only tracked for non-reset instructions)
  when (inIdle) {
    inflight := 0.U
  } .elsewhen (!dec.reset) {
    when (inExe && inflight =/= ((1 << pBits) - 1).U) { // overflow check
      inflight := inflight + 1.U
    } .elsewhen (mvc.io.acc_o.data.valid && !drained) { // underflow check
      inflight := inflight - 1.U
    }
  }

  // uop index: restart at uop_begin when idle or after the last uop
  when (inIdle || (inExe && lastUop)) {
    uop_idx := dec.uop_begin
  } .elsewhen (inExe) {
    uop_idx := uop_idx + 1.U
  }

  // outer loop: advances after the last uop of the last inner iteration
  when (inIdle) {
    Seq(cnt_o, acc_o, inp_o, wgt_o).foreach(_ := 0.U)
  } .elsewhen (inExe && lastUop && lastInner) {
    cnt_o := cnt_o + 1.U
    acc_o := acc_o + dec.acc_0
    inp_o := inp_o + dec.inp_0
    wgt_o := wgt_o + dec.wgt_0
  }

  // inner loop: advances after the last uop; reloads from the outer offsets
  // once the inner count has reached lp_1
  when (inIdle) {
    Seq(cnt_i, acc_i, inp_i, wgt_i).foreach(_ := 0.U)
  } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
    cnt_i := 0.U
    acc_i := acc_o
    inp_i := inp_o
    wgt_i := wgt_o
  } .elsewhen (inExe && lastUop) {
    cnt_i := cnt_i + 1.U
    acc_i := acc_i + dec.acc_1
    inp_i := inp_i + dec.inp_1
    wgt_i := wgt_i + dec.wgt_1
  }

  // scratchpad indices = uop fields + inner-loop offsets
  when (state === sComputeIdx && io.uop.data.valid) {
    uop_acc := io.uop.data.bits.u0 + acc_i
    uop_inp := io.uop.data.bits.u1 + inp_i
    uop_wgt := io.uop.data.bits.u2 + wgt_i
  }

  wrpipe.io.enq.valid := inExe & ~dec.reset
  wrpipe.io.enq.bits := uop_acc

  // uop
  io.uop.idx.valid := state === sReadUop
  io.uop.idx.bits := uop_idx

  // inp
  io.inp.rd.idx.valid := state === sReadTensor
  io.inp.rd.idx.bits := uop_inp
  io.inp.tieoffWrite() // read-only

  // wgt
  io.wgt.rd.idx.valid := state === sReadTensor
  io.wgt.rd.idx.bits := uop_wgt
  io.wgt.tieoffWrite() // read-only

  // acc_i
  io.acc.rd.idx.valid := state === sReadTensor
  io.acc.rd.idx.bits := uop_acc

  // mvc
  mvc.io.reset := dec.reset & inExe
  mvc.io.inp.data <> io.inp.rd.data
  mvc.io.wgt.data <> io.wgt.rd.data
  mvc.io.acc_i.data <> io.acc.rd.data

  // acc_o: a reset instruction writes at the current index, otherwise the
  // index (and its valid) come out of the write pipe
  io.acc.wr.valid := mvc.io.acc_o.data.valid & (dec.reset | wrpipe.io.deq.valid)
  io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
  io.acc.wr.bits.data <> mvc.io.acc_o.data.bits

  // out
  io.out.wr.valid := mvc.io.out.data.valid & wrpipe.io.deq.valid
  io.out.wr.bits.idx := wrpipe.io.deq.bits
  io.out.wr.bits.data <> mvc.io.out.data.bits
  io.out.tieoffRead() // write-only

  io.done := done

  // debug (printfs are suppressed for reset instructions)
  if (debug) {
    when (state === sReadUop && ~dec.reset) {
      printf("[TensorGemm] [uop] idx:%x\n", uop_idx)
    }

    when (state === sReadTensor && ~dec.reset) {
      printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt)
    }

    io.inp.rd.data.bits.zipWithIndex.foreach { case (r, i) =>
      when (io.inp.rd.data.valid && ~dec.reset) {
        printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
      }
    }

    io.wgt.rd.data.bits.zipWithIndex.foreach { case (r, i) =>
      when (io.wgt.rd.data.valid && ~dec.reset) {
        printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
      }
    }

    io.acc.rd.data.bits.foreach { tensor =>
      tensor.zipWithIndex.foreach { case (elem, i) =>
        when (io.acc.rd.data.valid && ~dec.reset) {
          printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
        }
      }
    }

    mvc.io.acc_o.data.bits.foreach { tensor =>
      tensor.zipWithIndex.foreach { case (elem, i) =>
        when (mvc.io.acc_o.data.valid && ~dec.reset) {
          printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
        }
      }
    }

    mvc.io.out.data.bits.foreach { tensor =>
      tensor.zipWithIndex.foreach { case (elem, i) =>
        when (mvc.io.out.data.valid && ~dec.reset) {
          printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)
        }
      }
    }
  }
}
|
|
@ -0,0 +1,278 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
import vta.shell._
|
||||
|
||||
/** TensorLoad.
 *
 * Load 1D and 2D tensors from main memory (DRAM) to input/weight
 * scratchpads (SRAM). Zero padding is supported while loading; it works
 * on the y and x axes and is managed by the TensorPadCtrl instances.
 * The TensorDataCtrl is in charge of the DRAM addressing and of
 * signalling row boundaries (stride/split) and completion.
 *
 * While the unit is idle the scratchpad can be written through
 * `io.tensor.wr`; it can be read through `io.tensor.rd`.
 *
 * @param tensorType tensor kind, forwarded to TensorParams and TensorClient
 * @param debug      when true, emit simulation printfs
 */
class TensorLoad(tensorType: String = "none", debug: Boolean = false)
  (implicit p: Parameters) extends Module {
  val tp = new TensorParams(tensorType)
  val mp = p(ShellKey).memParams
  val io = IO(new Bundle {
    val start = Input(Bool())
    val done = Output(Bool())
    val inst = Input(UInt(INST_BITS.W))
    val baddr = Input(UInt(mp.addrBits.W))
    val vme_rd = new VMEReadMaster
    val tensor = new TensorClient(tensorType)
  })
  val sizeFactor = tp.tensorLength * tp.numMemBlock
  val strideFactor = tp.tensorLength * tp.tensorWidth

  val dec = io.inst.asTypeOf(new MemDecode)
  val dataCtrl = Module(new TensorDataCtrl(sizeFactor, strideFactor))
  // Latched once the data controller reports done on a received beat.
  val dataCtrlDone = RegInit(false.B)
  val yPadCtrl0 = Module(new TensorPadCtrl(padType = "YPad0", sizeFactor))
  val yPadCtrl1 = Module(new TensorPadCtrl(padType = "YPad1", sizeFactor))
  val xPadCtrl0 = Module(new TensorPadCtrl(padType = "XPad0", sizeFactor))
  val xPadCtrl1 = Module(new TensorPadCtrl(padType = "XPad1", sizeFactor))

  // tag selects the memory block inside a scratchpad word,
  // set selects which of the tensorLength memories is written.
  val tag = Reg(UInt(8.W))
  val set = Reg(UInt(8.W))

  val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil = Enum(7)
  val state = RegInit(sIdle)

  // Frequently used conditions.
  private val idle = state === sIdle
  private val rdFire = io.vme_rd.data.fire()
  private val rowBreak = dataCtrl.io.stride | dataCtrl.io.split
  private val hasYPad0 = dec.ypad_0 =/= 0.U
  private val hasXPad0 = dec.xpad_0 =/= 0.U
  private val hasXPad1 = dec.xpad_1 =/= 0.U
  private val hasYPad1 = dec.ypad_1 =/= 0.U
  private val lastTag = tag === (tp.numMemBlock - 1).U
  private val lastSet = set === (tp.tensorLength - 1).U

  // True in any of the four zero-padding states.
  val isZeroPad = state === sYPad0 |
    state === sXPad0 |
    state === sXPad1 |
    state === sYPad1

  // A scratchpad block is produced on every received beat or padding cycle.
  private val wrEvent = rdFire || isZeroPad

  // Start of a row segment: optional leading x padding, then the read command.
  private val rowStartState = Mux(hasXPad0, sXPad0, sReadCmd)
  // After the last data (and trailing x padding): optional trailing y padding.
  private val finishState = Mux(hasYPad1, sYPad1, sIdle)

  // control
  switch (state) {
    is (sIdle) {
      when (io.start) {
        state := Mux(hasYPad0, sYPad0, rowStartState)
      }
    }
    is (sYPad0) {
      when (yPadCtrl0.io.done) {
        state := rowStartState
      }
    }
    is (sXPad0) {
      when (xPadCtrl0.io.done) {
        state := sReadCmd
      }
    }
    is (sReadCmd) {
      when (io.vme_rd.cmd.ready) {
        state := sReadData
      }
    }
    is (sReadData) {
      when (io.vme_rd.data.valid) {
        when (dataCtrl.io.done) {
          state := Mux(hasXPad1, sXPad1, finishState)
        } .elsewhen (rowBreak) {
          state := Mux(hasXPad1, sXPad1, rowStartState)
        }
      }
    }
    is (sXPad1) {
      when (xPadCtrl1.io.done) {
        state := Mux(dataCtrlDone, finishState, rowStartState)
      }
    }
    is (sYPad1) {
      when (yPadCtrl1.io.done && dataCtrlDone) {
        state := sIdle
      }
    }
  }

  // data controller
  dataCtrl.io.start := idle & io.start
  dataCtrl.io.inst := io.inst
  dataCtrl.io.baddr := io.baddr
  dataCtrl.io.xinit := io.vme_rd.cmd.fire()
  dataCtrl.io.xupdate := rdFire
  dataCtrl.io.yupdate := rdFire

  when (idle) {
    dataCtrlDone := false.B
  } .elsewhen (rdFire && dataCtrl.io.done) {
    dataCtrlDone := true.B
  }

  // pad
  yPadCtrl0.io.start := hasYPad0 & idle & io.start

  yPadCtrl1.io.start := hasYPad1 &
    ((rdFire & dataCtrl.io.done & !hasXPad1) |
     (state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone))

  xPadCtrl0.io.start := hasXPad0 &
    ((idle & io.start) |
     (state === sYPad0 & yPadCtrl0.io.done) |
     (rdFire & ~dataCtrlDone & rowBreak & !hasXPad1) |
     (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))

  xPadCtrl1.io.start := hasXPad1 & rdFire &
    ((dataCtrl.io.done) |
     (~dataCtrl.io.done & rowBreak & hasXPad1))

  yPadCtrl0.io.inst := io.inst
  yPadCtrl1.io.inst := io.inst
  xPadCtrl0.io.inst := io.inst
  xPadCtrl1.io.inst := io.inst

  // read-from-dram
  io.vme_rd.cmd.valid := state === sReadCmd
  io.vme_rd.cmd.bits.addr := dataCtrl.io.addr
  io.vme_rd.cmd.bits.len := dataCtrl.io.len

  io.vme_rd.data.ready := state === sReadData

  // write-to-sram
  // NOTE(review): the tag clears whenever it sits at its last value, even in a
  // cycle without a beat or padding write - confirm this is intended when the
  // memory inserts bubbles and numMemBlock > 1.
  when (idle || state === sReadCmd || lastTag) {
    tag := 0.U
  } .elsewhen (wrEvent) {
    tag := tag + 1.U
  }

  when (idle || state === sReadCmd || (lastSet && lastTag)) {
    set := 0.U
  } .elsewhen (wrEvent && lastTag) {
    set := set + 1.U
  }

  // waddr_cur is the scratchpad word being filled; waddr_nxt remembers the
  // start of the current row so a stride can jump by xsize.
  val waddr_cur = Reg(UInt(tp.memAddrBits.W))
  val waddr_nxt = Reg(UInt(tp.memAddrBits.W))
  when (idle) {
    waddr_cur := dec.sram_offset
    waddr_nxt := dec.sram_offset
  } .elsewhen (wrEvent && lastSet && lastTag) {
    waddr_cur := waddr_cur + 1.U
  } .elsewhen (dataCtrl.io.stride) {
    waddr_cur := waddr_nxt + dec.xsize
    waddr_nxt := waddr_nxt + dec.xsize
  }

  val tensorFile = Seq.fill(tp.tensorLength) { SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
  val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) }
  val wdata = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
  val no_mask = Wire(Vec(tp.numMemBlock, Bool()))
  no_mask.foreach { m => m := true.B }

  for (i <- 0 until tp.tensorLength) {
    // During a load only the block selected by tag is written; the data is
    // either the DRAM beat or zeros when padding.
    wmask(i).zipWithIndex.foreach { case (m, j) => m := tag === j.U }
    wdata(i).foreach { d => d := Mux(isZeroPad, 0.U, io.vme_rd.data.bits) }

    // When idle the memory is written from io.tensor.wr (full word, no mask);
    // otherwise from the load path, into the memory selected by set.
    val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i))
    val muxWen = Mux(idle, io.tensor.wr.valid, wrEvent & set === i.U)
    val muxWaddr = Mux(idle, io.tensor.wr.bits.idx, waddr_cur)
    val muxWdata = Mux(idle, tdata, wdata(i))
    val muxWmask = Mux(idle, no_mask, wmask(i))
    when (muxWen) {
      tensorFile(i).write(muxWaddr, muxWdata, muxWmask)
    }
  }

  // read-from-sram (data valid follows the index valid by one register)
  val rvalid = RegNext(io.tensor.rd.idx.valid)
  io.tensor.rd.data.valid := rvalid

  val rdata = tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid))
  rdata.zipWithIndex.foreach { case (r, i) =>
    io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i))
  }

  // done: on the last beat when no trailing padding is requested, otherwise
  // when the last requested padding controller finishes
  val done_no_pad = rdFire & dataCtrl.io.done & !hasXPad1 & !hasYPad1
  val done_x_pad = state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone & !hasYPad1
  val done_y_pad = state === sYPad1 & dataCtrlDone & yPadCtrl1.io.done
  io.done := done_no_pad | done_x_pad | done_y_pad

  // debug
  if (debug) {
    tensorType match {
      case "inp" =>
        when (io.vme_rd.cmd.fire()) {
          printf("[TensorLoad] [inp] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
        }
        when (state === sYPad0) {
          printf("[TensorLoad] [inp] sYPad0\n")
        }
        when (state === sYPad1) {
          printf("[TensorLoad] [inp] sYPad1\n")
        }
        when (state === sXPad0) {
          printf("[TensorLoad] [inp] sXPad0\n")
        }
        when (state === sXPad1) {
          printf("[TensorLoad] [inp] sXPad1\n")
        }
      case "wgt" =>
        when (io.vme_rd.cmd.fire()) {
          printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
        }
      case "acc" =>
        when (io.vme_rd.cmd.fire()) {
          printf("[TensorLoad] [acc] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
        }
      case _ =>
    }
  }
}
|
|
@ -0,0 +1,224 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
import vta.shell._
|
||||
|
||||
/** TensorStore.
 *
 * Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM).
 *
 * The scratchpad is filled through `io.tensor.wr`. A store walks `ysize`
 * rows; each row is sent as one or more write bursts of at most
 * `1 << lenBits` beats, and every burst is acknowledged through
 * `io.vme_wr.ack` before the next command is issued.
 *
 * @param tensorType tensor kind, forwarded to TensorParams and TensorClient
 * @param debug      when true, emit simulation printfs
 */
class TensorStore(tensorType: String = "true", debug: Boolean = false)
  (implicit p: Parameters) extends Module {
  val tp = new TensorParams(tensorType)
  val mp = p(ShellKey).memParams
  val io = IO(new Bundle {
    val start = Input(Bool())
    val done = Output(Bool())
    val inst = Input(UInt(INST_BITS.W))
    val baddr = Input(UInt(mp.addrBits.W))
    val vme_wr = new VMEWriteMaster
    val tensor = new TensorClient(tensorType)
  })
  val tensorLength = tp.tensorLength
  val tensorWidth = tp.tensorWidth
  val tensorElemBits = tp.tensorElemBits
  val memBlockBits = tp.memBlockBits
  val memDepth = tp.memDepth
  val numMemBlock = tp.numMemBlock

  val dec = io.inst.asTypeOf(new MemDecode)
  val waddr_cur = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
  val waddr_nxt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
  // xcnt counts beats sent in the current burst, xlen is the burst length
  // field given to the command, xrem is what is left of the row afterwards.
  val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
  val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
  val xrem = Reg(chiselTypeOf(dec.xsize))
  val xsize = (dec.xsize << log2Ceil(tensorLength*numMemBlock)) - 1.U
  val xmax = (1 << mp.lenBits).U
  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
  val ycnt = Reg(chiselTypeOf(dec.ysize))
  val ysize = dec.ysize
  // tag selects the memory block inside a scratchpad word,
  // set selects which of the tensorLength memories is read.
  val tag = Reg(UInt(8.W))
  val set = Reg(UInt(8.W))

  val sIdle :: sWriteCmd :: sWriteData :: sReadMem :: sWriteAck :: Nil = Enum(5)
  val state = RegInit(sIdle)

  // Frequently used conditions.
  private val wrFire = io.vme_wr.data.fire()
  private val acked = state === sWriteAck & io.vme_wr.ack
  private val lastTag = tag === (numMemBlock - 1).U
  private val lastSet = set === (tensorLength - 1).U
  private val lastRow = ycnt === ysize - 1.U

  // Program xlen/xrem for the first burst of a row: the whole row if it fits
  // in one burst, otherwise a maximum-length burst with the rest left in xrem.
  private def beginRow(): Unit = {
    when (xsize < xmax) {
      xlen := xsize
      xrem := 0.U
    } .otherwise {
      xlen := xmax - 1.U
      xrem := xsize - xmax
    }
  }

  // control
  switch (state) {
    is (sIdle) {
      when (io.start) {
        state := sWriteCmd
        beginRow()
      }
    }
    is (sWriteCmd) {
      when (io.vme_wr.cmd.ready) {
        state := sWriteData
      }
    }
    is (sWriteData) {
      when (io.vme_wr.data.ready) {
        when (xcnt === xlen) {
          state := sWriteAck
        } .elsewhen (lastTag) {
          // All blocks of the current word sent: fetch the next word.
          state := sReadMem
        }
      }
    }
    is (sReadMem) {
      state := sWriteData
    }
    is (sWriteAck) {
      when (io.vme_wr.ack) {
        when (xrem === 0.U) {
          // Row finished: either the whole store is done or the next row starts.
          when (lastRow) {
            state := sIdle
          } .otherwise {
            state := sWriteCmd
            beginRow()
          }
        } .elsewhen (xrem < xmax) {
          // Remainder fits in one more burst.
          state := sWriteCmd
          xlen := xrem
          xrem := 0.U
        } .otherwise {
          state := sWriteCmd
          xlen := xmax - 1.U
          xrem := xrem - xmax
        }
      }
    }
  }

  // write-to-sram
  val tensorFile = Seq.fill(tensorLength) { SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W))) }
  val wdata_t = Wire(Vec(numMemBlock, UInt(memBlockBits.W)))
  val no_mask = Wire(Vec(numMemBlock, Bool()))

  wdata_t := DontCare // only used as a type template
  no_mask.foreach { m => m := true.B }

  for (i <- 0 until tensorLength) {
    val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t)
    when (io.tensor.wr.valid) {
      tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask)
    }
  }

  // read-from-sram
  // stride marks the acknowledge that ends a row which is not the last one.
  val stride = acked &
    xcnt === xlen + 1.U &
    xrem === 0.U &
    !lastRow

  when (state === sIdle) {
    ycnt := 0.U
  } .elsewhen (stride) {
    ycnt := ycnt + 1.U
  }

  // NOTE(review): the tag clears whenever it sits at its last value, even in a
  // cycle where no beat is accepted - confirm this is intended when the
  // memory applies back-pressure and numMemBlock > 1.
  when (state === sWriteCmd || lastTag) {
    tag := 0.U
  } .elsewhen (wrFire) {
    tag := tag + 1.U
  }

  when (state === sWriteCmd || (lastSet && lastTag)) {
    set := 0.U
  } .elsewhen (wrFire && lastTag) {
    set := set + 1.U
  }

  // raddr_cur is the scratchpad word being sent; raddr_nxt remembers the
  // start of the current row so a stride can jump by xsize.
  val raddr_cur = Reg(UInt(tp.memAddrBits.W))
  val raddr_nxt = Reg(UInt(tp.memAddrBits.W))
  when (state === sIdle) {
    raddr_cur := dec.sram_offset
    raddr_nxt := dec.sram_offset
  } .elsewhen (wrFire && lastSet && lastTag) {
    raddr_cur := raddr_cur + 1.U
  } .elsewhen (stride) {
    raddr_cur := raddr_nxt + dec.xsize
    raddr_nxt := raddr_nxt + dec.xsize
  }

  // Every memory is read at raddr_cur; set picks the one that is sent.
  val tread = Seq.tabulate(tensorLength) { i => i.U ->
    tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem) }
  val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread)

  // write-to-dram
  when (state === sIdle) {
    waddr_cur := io.baddr + dec.dram_offset
    waddr_nxt := io.baddr + dec.dram_offset
  } .elsewhen (acked && xrem =/= 0.U) {
    // Same row continues: advance by one maximum-length burst.
    waddr_cur := waddr_cur + xmax_bytes
  } .elsewhen (stride) {
    waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
    waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
  }

  io.vme_wr.cmd.valid := state === sWriteCmd
  io.vme_wr.cmd.bits.addr := waddr_cur
  io.vme_wr.cmd.bits.len := xlen

  io.vme_wr.data.valid := state === sWriteData
  io.vme_wr.data.bits := mdata(tag)

  when (state === sWriteCmd) {
    xcnt := 0.U
  } .elsewhen (wrFire) {
    xcnt := xcnt + 1.U
  }

  // disable external read-from-sram requests
  io.tensor.tieoffRead()

  // done: acknowledge of the final burst of the final row
  io.done := acked & xrem === 0.U & lastRow

  // debug
  if (debug) {
    when (io.vme_wr.cmd.fire()) {
      printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n", ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem)
    }
    when (io.vme_wr.data.fire()) {
      printf("[TensorStore] data:%x\n", io.vme_wr.data.bits)
    }
    when (io.vme_wr.ack) {
      printf("[TensorStore] ack\n")
    }
  }
}
|
|
@ -0,0 +1,304 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
import vta.shell._
|
||||
|
||||
/** TensorParams.
 *
 * Base Bundle that derives the sizing parameters for one tensor type:
 * inputs (inp), weights (wgt), biases (acc), or outputs (out). Subclasses
 * reuse these values instead of recomputing them from CoreKey/ShellKey.
 *
 * Any other `tensorType` (including the default "none") fails the `require`.
 */
class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends Bundle {
  val errorMsg = s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"

  require (Seq("inp", "wgt", "acc", "out").contains(tensorType), errorMsg)

  // (rows, elements per row, bits per element) of one tensor of this type
  val (tensorLength, tensorWidth, tensorElemBits) = tensorType match {
    case "inp" => (p(CoreKey).batch, p(CoreKey).blockIn, p(CoreKey).inpBits)
    case "wgt" => (p(CoreKey).blockOut, p(CoreKey).blockIn, p(CoreKey).wgtBits)
    case "acc" => (p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).accBits)
    case _ => (p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).outBits)
  }

  // width of one memory word, and how many such words one tensor row occupies
  val memBlockBits = p(ShellKey).memParams.dataBits
  val numMemBlock = (tensorWidth * tensorElemBits) / memBlockBits

  // scratchpad depth for this tensor type
  val memDepth = tensorType match {
    case "inp" => p(CoreKey).inpMemDepth
    case "wgt" => p(CoreKey).wgtMemDepth
    case "acc" => p(CoreKey).accMemDepth
    case _ => p(CoreKey).outMemDepth
  }

  val memAddrBits = log2Ceil(memDepth)
}
|
||||
|
||||
/** TensorMaster.
 *
 * Requester side of the scratchpad tensor interface: it drives the read index
 * and the write index/data, and receives read data. For example, the
 * TensorGemm unit uses this interface for managing the inputs (inp),
 * weights (wgt), biases (acc), and outputs (out).
 */
class TensorMaster(tensorType: String = "none")
  (implicit p: Parameters) extends TensorParams(tensorType) {
  // read port: index goes out, tensor data comes back
  val rd = new Bundle {
    val idx = ValidIO(UInt(memAddrBits.W))
    val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
  }
  // write port: index and tensor data travel together
  val wr = ValidIO(new Bundle {
    val idx = UInt(memAddrBits.W)
    val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
  })

  /** Drive the read-request signals to an inactive, all-zero state. */
  def tieoffRead(): Unit = {
    rd.idx.valid := false.B
    rd.idx.bits := 0.U
  }

  /** Drive the write-request signals to an inactive, all-zero state. */
  def tieoffWrite(): Unit = {
    wr.valid := false.B
    wr.bits.idx := 0.U
    for (row <- wr.bits.data; elem <- row) {
      elem := 0.U
    }
  }

  override def cloneType =
    new TensorMaster(tensorType).asInstanceOf[this.type]
}
|
||||
|
||||
/** TensorClient.
 *
 * Responder side of the scratchpad tensor interface: it receives the read
 * index and the write index/data, and drives read data. For example, the
 * TensorLoad unit uses this interface for receiving read and write requests
 * from the TensorGemm unit.
 */
class TensorClient(tensorType: String = "none")
  (implicit p: Parameters) extends TensorParams(tensorType) {
  // read port: index comes in, tensor data goes out
  val rd = new Bundle {
    val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
    val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
  }
  // write port: index and tensor data arrive together
  val wr = Flipped(ValidIO(new Bundle {
    val idx = UInt(memAddrBits.W)
    val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
  }))

  /** Drive the read-response signals to an inactive, all-zero state. */
  def tieoffRead(): Unit = {
    rd.data.valid := false.B
    for (row <- rd.data.bits; elem <- row) {
      elem := 0.U
    }
  }

  override def cloneType =
    new TensorClient(tensorType).asInstanceOf[this.type]
}
|
||||
|
||||
/** TensorMasterData.
 *
 * Data-only variant of the tensor interface for pure datapath modules such as
 * MatrixVectorCore or AluVector. The direction follows the TensorMaster
 * convention for read data, so `data` is an input (Flipped).
 */
class TensorMasterData(tensorType: String = "none")
  (implicit p: Parameters) extends TensorParams(tensorType) {
  val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))

  override def cloneType =
    new TensorMasterData(tensorType).asInstanceOf[this.type]
}
|
||||
|
||||
/** TensorClientData.
 *
 * Data-only variant of the tensor interface for pure datapath modules such as
 * MatrixVectorCore or AluVector. The direction follows the TensorClient
 * convention for read data, so `data` is an output.
 */
class TensorClientData(tensorType: String = "none")
  (implicit p: Parameters) extends TensorParams(tensorType) {
  val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))

  override def cloneType =
    new TensorClientData(tensorType).asInstanceOf[this.type]
}
|
||||
|
||||
/** TensorPadCtrl. Zero-padding controller for TensorLoad.
 *
 * After `start`, the controller stays active while `xcnt` counts 0..xmax for
 * each `ycnt` in 0..ymax, and pulses `done` on the final count. The limits
 * are sampled from the decoded instruction while idle; which pad fields are
 * used depends on `padType` (YPad0, YPad1, XPad0 or XPad1).
 */
class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Module {
  val errorMsg = s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n"

  require (Seq("YPad0", "YPad1", "XPad0", "XPad1").contains(padType), errorMsg)

  val io = IO(new Bundle {
    val start = Input(Bool())
    val done = Output(Bool())
    val inst = Input(UInt(INST_BITS.W))
  })

  val dec = io.inst.asTypeOf(new MemDecode)

  val xmax = Reg(chiselTypeOf(dec.xsize))
  val ymax = Reg(chiselTypeOf(dec.ypad_0))
  val xcnt = Reg(chiselTypeOf(dec.xsize))
  val ycnt = Reg(chiselTypeOf(dec.ypad_0))

  // last x index: Y pads span the whole padded row, X pads only their own side
  val xval = padType match {
    case "YPad0" | "YPad1" =>
      ((dec.xpad_0 + dec.xsize + dec.xpad_1) << log2Ceil(sizeFactor)) - 1.U
    case "XPad0" =>
      (dec.xpad_0 << log2Ceil(sizeFactor)) - 1.U
    case _ =>
      (dec.xpad_1 << log2Ceil(sizeFactor)) - 1.U
  }

  // last y index: only Y pads iterate over more than one row
  val yval = padType match {
    case "YPad0" => Mux(dec.ypad_0 =/= 0.U, dec.ypad_0 - 1.U, 0.U)
    case "YPad1" => Mux(dec.ypad_1 =/= 0.U, dec.ypad_1 - 1.U, 0.U)
    case _ => 0.U
  }

  val sIdle :: sActive :: Nil = Enum(2)
  val state = RegInit(sIdle)

  switch (state) {
    is (sIdle) {
      when (io.start) { state := sActive }
    }
    is (sActive) {
      when ((ycnt === ymax) && (xcnt === xmax)) { state := sIdle }
    }
  }

  // limits track the instruction until the controller starts
  when (state === sIdle) {
    xmax := xval
    ymax := yval
  }

  // x counter wraps at xmax and is held at zero while idle
  when ((state === sIdle) || (xcnt === xmax)) {
    xcnt := 0.U
  } .elsewhen (state === sActive) {
    xcnt := xcnt + 1.U
  }

  // y counter advances once per completed x sweep
  when ((state === sIdle) || (ymax === 0.U)) {
    ycnt := 0.U
  } .elsewhen ((state === sActive) && (xcnt === xmax)) {
    ycnt := ycnt + 1.U
  }

  io.done := (state === sActive) && (ycnt === ymax) && (xcnt === xmax)
}
|
||||
|
||||
/** TensorDataCtrl. Data controller for TensorLoad.
 *
 * Produces the address (`addr`) and length (`len`) of each memory request
 * for the decoded instruction. A row whose length does not fit in `len` is
 * split into several requests (`split`); after a row is finished the
 * address moves on by `xstride` (`stride`). `done` is asserted on the last
 * beat of the last request of the last row.
 */
class TensorDataCtrl(sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module {
  val mp = p(ShellKey).memParams
  val io = IO(new Bundle {
    val start = Input(Bool())
    val done = Output(Bool())
    val inst = Input(UInt(INST_BITS.W))
    val baddr = Input(UInt(mp.addrBits.W))
    val xinit = Input(Bool())
    val xupdate = Input(Bool())
    val yupdate = Input(Bool())
    val stride = Output(Bool())
    val split = Output(Bool())
    val commit = Output(Bool())
    val addr = Output(UInt(mp.addrBits.W))
    val len = Output(UInt(mp.lenBits.W))
  })

  val dec = io.inst.asTypeOf(new MemDecode)

  // caddr: address of the current request; baddr: start address of the current row
  val caddr = Reg(UInt(mp.addrBits.W))
  val baddr = Reg(UInt(mp.addrBits.W))

  val len = Reg(UInt(mp.lenBits.W))

  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
  val xcnt = Reg(UInt(mp.lenBits.W))
  val xrem = Reg(chiselTypeOf(dec.xsize))
  val xsize = (dec.xsize << log2Ceil(sizeFactor)) - 1.U
  val xmax = (1 << mp.lenBits).U
  val ycnt = Reg(chiselTypeOf(dec.ysize))

  // current request finished, nothing left in this row, more rows to go
  val stride = (xcnt === len) && (xrem === 0.U) && (ycnt =/= dec.ysize - 1.U)

  // current request finished but the row still has data left
  val split = (xcnt === len) && (xrem =/= 0.U)

  when (io.start || (io.xupdate && stride)) {
    // first request of a row
    len := Mux(xsize < xmax, xsize, xmax - 1.U)
    xrem := Mux(xsize < xmax, 0.U, xsize - xmax)
  } .elsewhen (io.xupdate && split) {
    // follow-up request within the same row
    len := Mux(xrem < xmax, xrem, xmax - 1.U)
    xrem := Mux(xrem < xmax, 0.U, xrem - xmax)
  }

  when (io.xinit) {
    xcnt := 0.U
  } .elsewhen (io.xupdate) {
    xcnt := xcnt + 1.U
  }

  when (io.start) {
    ycnt := 0.U
  } .elsewhen (io.yupdate && stride) {
    ycnt := ycnt + 1.U
  }

  when (io.start) {
    val first_addr = io.baddr + dec.dram_offset
    caddr := first_addr
    baddr := first_addr
  } .elsewhen (io.yupdate) {
    when (split) {
      caddr := caddr + xmax_bytes
    } .elsewhen (stride) {
      val next_row_addr = baddr + (dec.xstride << log2Ceil(strideFactor))
      caddr := next_row_addr
      baddr := next_row_addr
    }
  }

  io.stride := stride
  io.split := split
  io.commit := xcnt === len
  io.addr := caddr
  io.len := len
  io.done := (xcnt === len) && (xrem === 0.U) && (ycnt === dec.ysize - 1.U)
}
|
|
@ -0,0 +1,23 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta
|
||||
|
||||
/** Package object for `vta.core` that mixes in ISAConstants, so the members of
 * ISAConstants are visible throughout the `vta.core` package without an
 * explicit import.
 */
|
||||
package object core extends vta.core.ISAConstants
|
|
@ -21,6 +21,9 @@ package vta.dpi
|
|||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
import vta.interface.axi._
|
||||
import vta.shell._
|
||||
|
||||
/** Host DPI parameters */
|
||||
trait VTAHostDPIParams {
|
||||
|
@ -70,3 +73,83 @@ class VTAHostDPI extends BlackBox with HasBlackBoxResource {
|
|||
})
|
||||
setResource("/verilog/VTAHostDPI.v")
|
||||
}
|
||||
|
||||
/** Host DPI to AXI Converter.
 *
 * Converts Host DPI requests into AXI-Lite transactions for VTAShell.
 * A request is latched while idle; `opcode` selects a write (address, data,
 * response) or a read (address, data) sequence. The DPI request is dequeued
 * when the address phase is accepted, and read data is passed back as the
 * DPI response.
 *
 * @param debug when true, print every address and data transfer
 */
class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {
    val dpi = new VTAHostDPIClient
    val axi = new AXILiteMaster(p(ShellKey).hostParams)
  })
  val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
  val data = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.value)))
  val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
  val state = RegInit(sIdle)

  switch (state) {
    is (sIdle) {
      when (io.dpi.req.valid) {
        state := Mux(io.dpi.req.opcode, sWriteAddress, sReadAddress)
      }
    }
    is (sReadAddress) {
      when (io.axi.ar.ready) { state := sReadData }
    }
    is (sReadData) {
      when (io.axi.r.valid) { state := sIdle }
    }
    is (sWriteAddress) {
      when (io.axi.aw.ready) { state := sWriteData }
    }
    is (sWriteData) {
      when (io.axi.w.ready) { state := sWriteResponse }
    }
    is (sWriteResponse) {
      when (io.axi.b.valid) { state := sIdle }
    }
  }

  // capture the request so it stays stable for the whole transaction
  when ((state === sIdle) && io.dpi.req.valid) {
    addr := io.dpi.req.addr
    data := io.dpi.req.value
  }

  // write channels
  io.axi.aw.valid := state === sWriteAddress
  io.axi.aw.bits.addr := addr
  io.axi.w.valid := state === sWriteData
  io.axi.w.bits.data := data
  io.axi.w.bits.strb := "h_f".U
  io.axi.b.ready := state === sWriteResponse

  // read channels
  io.axi.ar.valid := state === sReadAddress
  io.axi.ar.bits.addr := addr
  io.axi.r.ready := state === sReadData

  // DPI side
  io.dpi.req.deq := ((state === sReadAddress) && io.axi.ar.ready) || ((state === sWriteAddress) && io.axi.aw.ready)
  io.dpi.resp.valid := io.axi.r.valid
  io.dpi.resp.bits := io.axi.r.bits.data

  if (debug) {
    when ((state === sWriteAddress) && io.axi.aw.ready) {
      printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr)
    }
    when ((state === sReadAddress) && io.axi.ar.ready) {
      printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr)
    }
    when (io.axi.r.fire()) {
      printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data)
    }
    when (io.axi.w.fire()) {
      printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data)
    }
  }
}
|
||||
|
|
|
@ -21,6 +21,9 @@ package vta.dpi
|
|||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
import vta.interface.axi._
|
||||
import vta.shell._
|
||||
|
||||
/** Memory DPI parameters */
|
||||
trait VTAMemDPIParams {
|
||||
|
@ -71,3 +74,98 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource {
|
|||
})
|
||||
setResource("/verilog/VTAMemDPI.v")
|
||||
}
|
||||
|
||||
/** Memory DPI to AXI Converter.
 *
 * Accepts AXI read and write bursts on the client port and forwards them to
 * the memory DPI. The request (opcode, len, addr) is latched while idle and
 * presented on `dpi.req` during the address state. Read data flows from
 * `dpi.rd` to the AXI R channel while `len` counts down to zero; write data
 * flows from the AXI W channel to `dpi.wr` until `last`.
 *
 * @param debug when true, print every address and data transfer
 */
class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {
    val dpi = new VTAMemDPIMaster
    val axi = new AXIClient(p(ShellKey).memParams)
  })
  val opcode = RegInit(false.B)
  val len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.len)))
  val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
  val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
  val state = RegInit(sIdle)

  switch (state) {
    is (sIdle) {
      // reads take priority when both address channels are valid
      when (io.axi.ar.valid) {
        state := sReadAddress
      } .elsewhen (io.axi.aw.valid) {
        state := sWriteAddress
      }
    }
    is (sReadAddress) {
      when (io.axi.ar.valid) { state := sReadData }
    }
    is (sReadData) {
      when (io.axi.r.ready && io.dpi.rd.valid && (len === 0.U)) { state := sIdle }
    }
    is (sWriteAddress) {
      when (io.axi.aw.valid) { state := sWriteData }
    }
    is (sWriteData) {
      when (io.axi.w.valid && io.axi.w.bits.last) { state := sWriteResponse }
    }
    is (sWriteResponse) {
      when (io.axi.b.ready) { state := sIdle }
    }
  }

  when (state === sIdle) {
    // latch the pending request, with the same read-first priority as above
    when (io.axi.ar.valid) {
      opcode := false.B
      len := io.axi.ar.bits.len
      addr := io.axi.ar.bits.addr
    } .elsewhen (io.axi.aw.valid) {
      opcode := true.B
      len := io.axi.aw.bits.len
      addr := io.axi.aw.bits.addr
    }
  } .elsewhen (state === sReadData) {
    // one beat transferred: count the remaining beats down
    when (io.axi.r.ready && io.dpi.rd.valid && (len =/= 0.U)) {
      len := len - 1.U
    }
  }

  // DPI request
  io.dpi.req.valid := ((state === sReadAddress) && io.axi.ar.valid) || ((state === sWriteAddress) && io.axi.aw.valid)
  io.dpi.req.opcode := opcode
  io.dpi.req.len := len
  io.dpi.req.addr := addr

  // address channels
  io.axi.ar.ready := state === sReadAddress
  io.axi.aw.ready := state === sWriteAddress

  // read data: DPI -> AXI
  io.axi.r.valid := (state === sReadData) && io.dpi.rd.valid
  io.axi.r.bits.data := io.dpi.rd.bits
  io.axi.r.bits.last := len === 0.U
  io.axi.r.bits.resp := 0.U
  io.axi.r.bits.user := 0.U
  io.axi.r.bits.id := 0.U
  io.dpi.rd.ready := (state === sReadData) && io.axi.r.ready

  // write data: AXI -> DPI
  io.dpi.wr.valid := (state === sWriteData) && io.axi.w.valid
  io.dpi.wr.bits := io.axi.w.bits.data
  io.axi.w.ready := state === sWriteData

  // write response
  io.axi.b.valid := state === sWriteResponse
  io.axi.b.bits.resp := 0.U
  io.axi.b.bits.user := 0.U
  io.axi.b.bits.id := 0.U

  if (debug) {
    when ((state === sReadAddress) && io.axi.ar.valid) {
      printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len)
    }
    when ((state === sWriteAddress) && io.axi.aw.valid) {
      printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len)
    }
    when (io.axi.r.fire()) {
      printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n", io.axi.r.bits.last, io.axi.r.bits.data)
    }
    when (io.axi.w.fire()) {
      printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n", io.axi.w.bits.last, io.axi.w.bits.data)
    }
  }
}
|
||||
|
|
|
@ -0,0 +1,312 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.interface.axi
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.genericbundle._
|
||||
|
||||
/** AXIParams.
 *
 * Width and constant parameters shared by the AXI and AXI-Lite bundles.
 *
 * @param addrBits address width in bits, must be positive
 * @param dataBits data width in bits, must be a power of two and at least 8,
 *                 because `strbBits` (one strobe bit per byte) and
 *                 `sizeConst` (log2 of the bytes per beat) are derived from
 *                 `dataBits / 8` and are only exact for such widths
 */
case class AXIParams(
  addrBits: Int = 32,
  dataBits: Int = 64
)
{
  require (addrBits > 0, s"[AXIParams] addrBits must be positive, got $addrBits")
  // A merely even width (e.g. 10 or 24) would pass the old `% 2` check but
  // yield a strobe width and size encoding that do not match the data bus.
  require (dataBits >= 8 && isPow2(dataBits),
    s"[AXIParams] dataBits must be a power of two >= 8, got $dataBits")

  // field widths
  val idBits = 1
  val userBits = 1
  val strbBits = dataBits/8
  val lenBits = 8
  val sizeBits = 3
  val burstBits = 2
  val lockBits = 2
  val cacheBits = 4
  val protBits = 3
  val qosBits = 4
  val regionBits = 4
  val respBits = 2

  // fixed values driven onto the matching fields by AXIMaster.setConst
  val sizeConst = log2Ceil(dataBits/8)
  val idConst = 0
  val userConst = 0
  val burstConst = 1
  val lockConst = 0
  val cacheConst = 3
  val protConst = 0
  val qosConst = 0
  val regionConst = 0
}
|
||||
|
||||
/** Common base of every AXI bundle in this file; passes `params` on to
 * GenericParameterizedBundle.
 */
abstract class AXIBase(params: AXIParams) extends GenericParameterizedBundle(params)
|
||||
|
||||
// AXILite
|
||||
|
||||
/** Payload of the AXI-Lite address channels (AW and AR). */
class AXILiteAddress(params: AXIParams) extends AXIBase(params) {
  val addr = UInt(params.addrBits.W)
}
|
||||
|
||||
/** Payload of the AXI-Lite write-data channel (W): data and byte strobes. */
class AXILiteWriteData(params: AXIParams) extends AXIBase(params) {
  val data = UInt(params.dataBits.W)
  val strb = UInt(params.strbBits.W)
}
|
||||
|
||||
/** Payload of the AXI-Lite write-response channel (B). */
class AXILiteWriteResponse(params: AXIParams) extends AXIBase(params) {
  val resp = UInt(params.respBits.W)
}
|
||||
|
||||
/** Payload of the AXI-Lite read-data channel (R): data and response code. */
class AXILiteReadData(params: AXIParams) extends AXIBase(params) {
  val data = UInt(params.dataBits.W)
  val resp = UInt(params.respBits.W)
}
|
||||
|
||||
/** AXI-Lite port seen from the master: AW, W and AR are driven by this side,
 * B and R are received.
 */
class AXILiteMaster(params: AXIParams) extends AXIBase(params) {
  val aw = Decoupled(new AXILiteAddress(params))
  val w = Decoupled(new AXILiteWriteData(params))
  val b = Flipped(Decoupled(new AXILiteWriteResponse(params)))
  val ar = Decoupled(new AXILiteAddress(params))
  val r = Flipped(Decoupled(new AXILiteReadData(params)))

  /** Drive every master-side signal to an inactive, all-zero state. */
  def tieoff(): Unit = {
    // address channels
    for (ch <- Seq(aw, ar)) {
      ch.valid := false.B
      ch.bits.addr := 0.U
    }
    // write data
    w.valid := false.B
    w.bits.data := 0.U
    w.bits.strb := 0.U
    // response channels are never accepted
    b.ready := false.B
    r.ready := false.B
  }
}
|
||||
|
||||
/** AXI-Lite port seen from the client: AW, W and AR are received, B and R are
 * driven by this side.
 */
class AXILiteClient(params: AXIParams) extends AXIBase(params) {
  val aw = Flipped(Decoupled(new AXILiteAddress(params)))
  val w = Flipped(Decoupled(new AXILiteWriteData(params)))
  val b = Decoupled(new AXILiteWriteResponse(params))
  val ar = Flipped(Decoupled(new AXILiteAddress(params)))
  val r = Decoupled(new AXILiteReadData(params))

  /** Drive every client-side signal to an inactive, all-zero state. */
  def tieoff(): Unit = {
    // requests are never accepted
    for (req <- Seq(aw.ready, w.ready, ar.ready)) {
      req := false.B
    }
    // write response
    b.valid := false.B
    b.bits.resp := 0.U
    // read data
    r.valid := false.B
    r.bits.resp := 0.U
    r.bits.data := 0.U
  }
}
|
||||
|
||||
// AXI extends AXILite
|
||||
|
||||
/** Payload of the AXI address channels (AW and AR): the AXI-Lite address plus
 * the burst and sideband fields.
 */
class AXIAddress(params: AXIParams) extends AXILiteAddress(params) {
  val id = UInt(params.idBits.W)
  val user = UInt(params.userBits.W)
  // burst description
  val len = UInt(params.lenBits.W)
  val size = UInt(params.sizeBits.W)
  val burst = UInt(params.burstBits.W)
  // sideband attributes
  val lock = UInt(params.lockBits.W)
  val cache = UInt(params.cacheBits.W)
  val prot = UInt(params.protBits.W)
  val qos = UInt(params.qosBits.W)
  val region = UInt(params.regionBits.W)
}
|
||||
|
||||
/** Payload of the AXI write-data channel (W): the AXI-Lite fields plus `last`,
 * `id` and `user`.
 */
class AXIWriteData(params: AXIParams) extends AXILiteWriteData(params) {
  val last = Bool()
  val id = UInt(params.idBits.W)
  val user = UInt(params.userBits.W)
}
|
||||
|
||||
/** Payload of the AXI write-response channel (B): the AXI-Lite response plus
 * `id` and `user`.
 */
class AXIWriteResponse(params: AXIParams) extends AXILiteWriteResponse(params) {
  val id = UInt(params.idBits.W)
  val user = UInt(params.userBits.W)
}
|
||||
|
||||
/** Payload of the AXI read-data channel (R): the AXI-Lite fields plus `last`,
 * `id` and `user`.
 */
class AXIReadData(params: AXIParams) extends AXILiteReadData(params) {
  val last = Bool()
  val id = UInt(params.idBits.W)
  val user = UInt(params.userBits.W)
}
|
||||
|
||||
/** AXI port seen from the master: AW, W and AR are driven by this side, B and
 * R are received.
 */
class AXIMaster(params: AXIParams) extends AXIBase(params) {
  val aw = Decoupled(new AXIAddress(params))
  val w = Decoupled(new AXIWriteData(params))
  val b = Flipped(Decoupled(new AXIWriteResponse(params)))
  val ar = Decoupled(new AXIAddress(params))
  val r = Flipped(Decoupled(new AXIReadData(params)))

  /** Drive every master-side signal to an inactive, all-zero state. */
  def tieoff(): Unit = {
    // both address channels carry the same payload type
    for (ch <- Seq(aw, ar)) {
      ch.valid := false.B
      ch.bits.addr := 0.U
      ch.bits.id := 0.U
      ch.bits.user := 0.U
      ch.bits.len := 0.U
      ch.bits.size := 0.U
      ch.bits.burst := 0.U
      ch.bits.lock := 0.U
      ch.bits.cache := 0.U
      ch.bits.prot := 0.U
      ch.bits.qos := 0.U
      ch.bits.region := 0.U
    }
    // write data
    w.valid := false.B
    w.bits.data := 0.U
    w.bits.strb := 0.U
    w.bits.last := false.B
    w.bits.id := 0.U
    w.bits.user := 0.U
    // response channels are never accepted
    b.ready := false.B
    r.ready := false.B
  }

  /** Drive the fields that hold a fixed value from `params`. The address,
   * length, data, last and handshake signals are left to the caller.
   */
  def setConst(): Unit = {
    for (ch <- Seq(aw, ar)) {
      ch.bits.user := params.userConst.U
      ch.bits.burst := params.burstConst.U
      ch.bits.lock := params.lockConst.U
      ch.bits.cache := params.cacheConst.U
      ch.bits.prot := params.protConst.U
      ch.bits.qos := params.qosConst.U
      ch.bits.region := params.regionConst.U
      ch.bits.size := params.sizeConst.U
      ch.bits.id := params.idConst.U
    }
    w.bits.id := params.idConst.U
    w.bits.user := params.userConst.U
    // all byte lanes enabled
    w.bits.strb := Fill(params.strbBits, true.B)
  }
}
|
||||
|
||||
/** AXI port seen from the client: AW, W and AR are received, B and R are
 * driven by this side.
 */
class AXIClient(params: AXIParams) extends AXIBase(params) {
  val aw = Flipped(Decoupled(new AXIAddress(params)))
  val w = Flipped(Decoupled(new AXIWriteData(params)))
  val b = Decoupled(new AXIWriteResponse(params))
  val ar = Flipped(Decoupled(new AXIAddress(params)))
  val r = Decoupled(new AXIReadData(params))

  /** Drive every client-side signal to an inactive, all-zero state. */
  def tieoff(): Unit = {
    // requests are never accepted
    for (req <- Seq(aw.ready, w.ready, ar.ready)) {
      req := false.B
    }
    // write response
    b.valid := false.B
    b.bits.resp := 0.U
    b.bits.user := 0.U
    b.bits.id := 0.U
    // read data
    r.valid := false.B
    r.bits.resp := 0.U
    r.bits.data := 0.U
    r.bits.user := 0.U
    r.bits.last := false.B
    r.bits.id := 0.U
  }
}
|
||||
|
||||
// The XilinxAXILiteClient and XilinxAXIMaster bundles exist for wrapper
// purposes only: the package-RTL tool in Xilinx Vivado accepts only certain
// port-name formats.

/** AXI-Lite client port with flat, upper-case signal names. */
class XilinxAXILiteClient(params: AXIParams) extends AXIBase(params) {
  // write address
  val AWVALID = Input(Bool())
  val AWREADY = Output(Bool())
  val AWADDR = Input(UInt(params.addrBits.W))
  // write data
  val WVALID = Input(Bool())
  val WREADY = Output(Bool())
  val WDATA = Input(UInt(params.dataBits.W))
  val WSTRB = Input(UInt(params.strbBits.W))
  // write response
  val BVALID = Output(Bool())
  val BREADY = Input(Bool())
  val BRESP = Output(UInt(params.respBits.W))
  // read address
  val ARVALID = Input(Bool())
  val ARREADY = Output(Bool())
  val ARADDR = Input(UInt(params.addrBits.W))
  // read data
  val RVALID = Output(Bool())
  val RREADY = Input(Bool())
  val RDATA = Output(UInt(params.dataBits.W))
  val RRESP = Output(UInt(params.respBits.W))
}
|
||||
|
||||
/** AXI master port with flat, upper-case signal names. */
class XilinxAXIMaster(params: AXIParams) extends AXIBase(params) {
  // write address
  val AWVALID = Output(Bool())
  val AWREADY = Input(Bool())
  val AWADDR = Output(UInt(params.addrBits.W))
  val AWID = Output(UInt(params.idBits.W))
  val AWUSER = Output(UInt(params.userBits.W))
  val AWLEN = Output(UInt(params.lenBits.W))
  val AWSIZE = Output(UInt(params.sizeBits.W))
  val AWBURST = Output(UInt(params.burstBits.W))
  val AWLOCK = Output(UInt(params.lockBits.W))
  val AWCACHE = Output(UInt(params.cacheBits.W))
  val AWPROT = Output(UInt(params.protBits.W))
  val AWQOS = Output(UInt(params.qosBits.W))
  val AWREGION = Output(UInt(params.regionBits.W))
  // write data
  val WVALID = Output(Bool())
  val WREADY = Input(Bool())
  val WDATA = Output(UInt(params.dataBits.W))
  val WSTRB = Output(UInt(params.strbBits.W))
  val WLAST = Output(Bool())
  val WID = Output(UInt(params.idBits.W))
  val WUSER = Output(UInt(params.userBits.W))
  // write response
  val BVALID = Input(Bool())
  val BREADY = Output(Bool())
  val BRESP = Input(UInt(params.respBits.W))
  val BID = Input(UInt(params.idBits.W))
  val BUSER = Input(UInt(params.userBits.W))
  // read address
  val ARVALID = Output(Bool())
  val ARREADY = Input(Bool())
  val ARADDR = Output(UInt(params.addrBits.W))
  val ARID = Output(UInt(params.idBits.W))
  val ARUSER = Output(UInt(params.userBits.W))
  val ARLEN = Output(UInt(params.lenBits.W))
  val ARSIZE = Output(UInt(params.sizeBits.W))
  val ARBURST = Output(UInt(params.burstBits.W))
  val ARLOCK = Output(UInt(params.lockBits.W))
  val ARCACHE = Output(UInt(params.cacheBits.W))
  val ARPROT = Output(UInt(params.protBits.W))
  val ARQOS = Output(UInt(params.qosBits.W))
  val ARREGION = Output(UInt(params.regionBits.W))
  // read data
  val RVALID = Input(Bool())
  val RREADY = Output(Bool())
  val RDATA = Input(UInt(params.dataBits.W))
  val RRESP = Input(UInt(params.respBits.W))
  val RLAST = Input(Bool())
  val RID = Input(UInt(params.idBits.W))
  val RUSER = Input(UInt(params.userBits.W))
}
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.shell
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
import vta.interface.axi._
|
||||
|
||||
/** PynqConfig. Shell configuration for Pynq: 16-bit address / 32-bit data on
 * the host port and 32-bit address / 64-bit data on the memory port.
 */
class PynqConfig extends Config((site, here, up) => {
  case ShellKey =>
    val host = AXIParams(addrBits = 16, dataBits = 32)
    val mem = AXIParams(addrBits = 32, dataBits = 64)
    ShellParams(
      hostParams = host,
      memParams = mem,
      vcrParams = VCRParams(),
      vmeParams = VMEParams())
})
|
||||
|
||||
/** F1Config. Shell configuration for F1: 16-bit address / 32-bit data on the
 * host port and 64-bit address / 64-bit data on the memory port.
 */
class F1Config extends Config((site, here, up) => {
  case ShellKey =>
    val host = AXIParams(addrBits = 16, dataBits = 32)
    val mem = AXIParams(addrBits = 64, dataBits = 64)
    ShellParams(
      hostParams = host,
      memParams = mem,
      vcrParams = VCRParams(),
      vmeParams = VMEParams())
})
|
|
@ -0,0 +1,78 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.shell
|
||||
|
||||
import chisel3._
|
||||
import vta.util.config._
|
||||
import vta.interface.axi._
|
||||
import vta.shell._
|
||||
import vta.dpi._
|
||||
|
||||
/** VTAHost.
  *
  * Simulation-only module that translates the DPI protocol into AXI so that
  * host-to-VTA communication can be tested. Update this module when testing
  * hosts that use a bus protocol other than AXI.
  */
class VTAHost(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {
    val axi = new AXILiteMaster(p(ShellKey).hostParams)
  })

  val host_dpi = Module(new VTAHostDPI)
  val host_axi = Module(new VTAHostDPIToAXI)

  // The DPI module takes its clock and reset as explicit ports.
  host_dpi.io.clock := clock
  host_dpi.io.reset := reset

  // DPI -> AXI-Lite bridge, exposed on this module's master port.
  host_axi.io.dpi <> host_dpi.io.dpi
  io.axi <> host_axi.io.axi
}
|
||||
|
||||
/** VTAMem.
  *
  * Simulation-only module that translates the DPI protocol into AXI so that
  * VTA-to-memory communication can be tested. Update this module when testing
  * memories that use a bus protocol other than AXI.
  */
class VTAMem(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {
    val axi = new AXIClient(p(ShellKey).memParams)
  })

  val mem_dpi = Module(new VTAMemDPI)
  val mem_axi = Module(new VTAMemDPIToAXI)

  // The DPI module takes its clock and reset as explicit ports.
  mem_dpi.io.clock := clock
  mem_dpi.io.reset := reset

  // AXI client port -> bridge -> DPI memory model.
  mem_dpi.io.dpi <> mem_axi.io.dpi
  mem_axi.io.axi <> io.axi
}
|
||||
|
||||
/** SimShell.
  *
  * The simulation shell instantiates the host and memory simulation modules
  * and is intended to be connected to the VTAShell.
  */
class SimShell(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {
    val mem = new AXIClient(p(ShellKey).memParams)
    val host = new AXILiteMaster(p(ShellKey).hostParams)
  })

  val host = Module(new VTAHost)
  val mem = Module(new VTAMem)

  // Forward both simulation models straight to the shell boundary.
  io.mem <> mem.io.axi
  io.host <> host.io.axi
}
|
|
@ -0,0 +1,242 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.shell
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
import vta.util.genericbundle._
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.collection.mutable.LinkedHashMap
|
||||
import vta.interface.axi._
|
||||
|
||||
/** VCR parameters.
  *
  * These parameters are used on VCR interfaces and modules.
  */
case class VCRParams() {
  // Number of value registers (must be positive).
  val nValsReg: Int = 1
  // Number of pointer registers (must be positive).
  val nPtrsReg: Int = 6
  // Width in bits of a single register.
  val regBits: Int = 32
  // Number of control registers.
  val nCtrlReg: Int = 4
  // Address of the first control register.
  val ctrlBaseAddr: Int = 0

  require(nValsReg > 0)
  require(nPtrsReg > 0)
}
|
||||
|
||||
/** VCRBase.
  *
  * Parameterized base class shared by the VCR bundles.
  */
abstract class VCRBase(implicit p: Parameters) extends GenericParameterizedBundle(p)
|
||||
|
||||
/** VCRMaster.
  *
  * Master-side interface used by the VCR in the VTAShell to control the
  * Core unit. `launch`, `irq`, `ptrs` and `vals` are outputs; `finish` is
  * the only input.
  */
class VCRMaster(implicit p: Parameters) extends VCRBase {
  val vp = p(ShellKey).vcrParams
  val mp = p(ShellKey).memParams

  val launch = Output(Bool())
  val finish = Input(Bool())
  val irq = Output(Bool())
  // One pointer per pointer register, as wide as a memory address.
  val ptrs = Output(Vec(vp.nPtrsReg, UInt(mp.addrBits.W)))
  // One value per value register, as wide as a control register.
  val vals = Output(Vec(vp.nValsReg, UInt(vp.regBits.W)))
}
|
||||
|
||||
/** VCRClient.
  *
  * Client-side interface used by the Core module to communicate with the
  * VCR in the VTAShell. It mirrors VCRMaster with every direction flipped.
  */
class VCRClient(implicit p: Parameters) extends VCRBase {
  val vp = p(ShellKey).vcrParams
  val mp = p(ShellKey).memParams

  val launch = Input(Bool())
  val finish = Output(Bool())
  val irq = Input(Bool())
  // One pointer per pointer register, as wide as a memory address.
  val ptrs = Input(Vec(vp.nPtrsReg, UInt(mp.addrBits.W)))
  // One value per value register, as wide as a control register.
  val vals = Input(Vec(vp.nValsReg, UInt(vp.regBits.W)))
}
|
||||
|
||||
/** VTA Control Registers (VCR).
  *
  * This unit provides control registers (32 and 64 bits) to be used by a
  * control unit, typically a host processor. These registers are read-only
  * by the core at the moment, but this will likely change once we add support
  * for general purpose registers that could be used as event counters by the
  * Core unit.
  *
  * Address map, in order: control registers, value registers, pointer
  * registers. A 64-bit pointer occupies two consecutive register slots.
  */
class VCR(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {
    val host = new AXILiteClient(p(ShellKey).hostParams)
    val vcr = new VCRMaster
  })

  val vp = p(ShellKey).vcrParams
  val mp = p(ShellKey).memParams
  val hp = p(ShellKey).hostParams

  // ---------------------------------------------------------------------
  // Write control (AW, W, B)
  // ---------------------------------------------------------------------
  val waddr = RegInit("h_ffff".U(hp.addrBits.W)) // init with invalid address
  val wdata = io.host.w.bits.data
  val wstrb = io.host.w.bits.strb
  // Expand each of the four byte strobes into eight mask bits (MSB first).
  val wmask = Cat((3 to 0 by -1).map(i => Fill(8, wstrb(i))))
  val sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(3)
  val wstate = RegInit(sWriteAddress)

  // address -> data -> response -> address
  switch(wstate) {
    is(sWriteAddress) {
      when(io.host.aw.valid) { wstate := sWriteData }
    }
    is(sWriteData) {
      when(io.host.w.valid) { wstate := sWriteResponse }
    }
    is(sWriteResponse) {
      when(io.host.b.ready) { wstate := sWriteAddress }
    }
  }

  // Latch the write address when the AW handshake completes.
  when(io.host.aw.fire()) {
    waddr := io.host.aw.bits.addr
  }

  io.host.aw.ready := wstate === sWriteAddress
  io.host.w.ready := wstate === sWriteData
  io.host.b.valid := wstate === sWriteResponse
  io.host.b.bits.resp := "h_0".U

  // ---------------------------------------------------------------------
  // Read control (AR, R)
  // ---------------------------------------------------------------------
  val sReadAddress :: sReadData :: Nil = Enum(2)
  val rstate = RegInit(sReadAddress)

  // address -> data -> address
  switch(rstate) {
    is(sReadAddress) {
      when(io.host.ar.valid) { rstate := sReadData }
    }
    is(sReadData) {
      when(io.host.r.ready) { rstate := sReadAddress }
    }
  }

  io.host.ar.ready := rstate === sReadAddress
  io.host.r.valid := rstate === sReadData

  // ---------------------------------------------------------------------
  // Address map
  // ---------------------------------------------------------------------
  val nPtrsReg = vp.nPtrsReg
  val nValsReg = vp.nValsReg
  val regBits = vp.regBits
  val ptrsBits = mp.addrBits
  val nCtrlReg = vp.nCtrlReg
  val rStride = regBits/8   // bytes per register
  val pStride = ptrsBits/8  // bytes per pointer
  val ctrlBaseAddr = vp.ctrlBaseAddr
  val valsBaseAddr = ctrlBaseAddr + nCtrlReg*rStride
  val ptrsBaseAddr = valsBaseAddr + nValsReg*rStride

  val ctrlAddr = Seq.tabulate(nCtrlReg)(i => i*rStride + ctrlBaseAddr)
  val valsAddr = Seq.tabulate(nValsReg)(i => i*rStride + valsBaseAddr)

  // A 64-bit pointer contributes two addresses: low word, then high word.
  val ptrsAddr = new ListBuffer[Int]()
  (0 until nPtrsReg).foreach { i =>
    ptrsAddr += i*pStride + ptrsBaseAddr
    if (ptrsBits == 64) {
      ptrsAddr += i*pStride + rStride + ptrsBaseAddr
    }
  }

  // ---------------------------------------------------------------------
  // AP register (control register 0)
  // ---------------------------------------------------------------------
  val c0 = RegInit(VecInit(Seq.fill(regBits)(false.B)))

  // ap start: set by a host write of bit 0, cleared when the core finishes
  when(io.host.w.fire() && waddr === ctrlAddr(0).asUInt && wstrb(0) && wdata(0)) {
    c0(0) := true.B
  } .elsewhen(io.vcr.finish) {
    c0(0) := false.B
  }

  // ap done = finish: set on finish, cleared when the host reads register 0
  when(io.vcr.finish) {
    c0(1) := true.B
  } .elsewhen(io.host.ar.fire() && io.host.ar.bits.addr === ctrlAddr(0).asUInt) {
    c0(1) := false.B
  }

  // Remaining control registers are constant zero.
  val c1 = 0.U
  val c2 = 0.U
  val c3 = 0.U

  val ctrlRegList = List(c0, c1, c2, c3)

  io.vcr.launch := c0(0)

  // interrupts not supported atm
  io.vcr.irq := false.B

  // ---------------------------------------------------------------------
  // Write pointer and value registers
  // ---------------------------------------------------------------------
  val pvAddr = valsAddr ++ ptrsAddr
  val pvNumReg = if (ptrsBits == 64) nValsReg + nPtrsReg*2 else nValsReg + nPtrsReg
  val pvReg = RegInit(VecInit(Seq.fill(pvNumReg)(0.U(regBits.W))))
  val pvRegList = new ListBuffer[UInt]()

  (0 until pvNumReg).foreach { i =>
    // Byte-masked update: strobed bytes come from wdata, others are kept.
    when(io.host.w.fire() && (waddr === pvAddr(i).U)) {
      pvReg(i) := (wdata & wmask) | (pvReg(i) & ~wmask)
    }
    pvRegList += pvReg(i)
  }

  (0 until nValsReg).foreach { i =>
    io.vcr.vals(i) := pvReg(i)
  }

  (0 until nPtrsReg).foreach { i =>
    if (ptrsBits == 64) {
      // High word sits in the odd slot, low word in the even slot.
      io.vcr.ptrs(i) := Cat(pvReg(nValsReg + i*2 + 1), pvReg(nValsReg + i*2))
    } else {
      io.vcr.ptrs(i) := pvReg(nValsReg + i)
    }
  }

  // ---------------------------------------------------------------------
  // Read pointer and value registers
  // ---------------------------------------------------------------------
  val mapAddr = ctrlAddr ++ valsAddr ++ ptrsAddr
  val mapRegList = ctrlRegList ++ pvRegList

  val rdata = RegInit(0.U(regBits.W))
  val rmap = LinkedHashMap[Int,UInt]()

  val totalReg = mapRegList.length
  (0 until totalReg).foreach { i =>
    rmap += mapAddr(i) -> mapRegList(i).asUInt
  }

  // One address-match signal per mapped register.
  val decodeAddr = rmap map { case (k, _) => k -> (io.host.ar.bits.addr === k.asUInt) }

  // Capture the addressed register when the AR handshake completes.
  when(io.host.ar.fire()) {
    rdata := Mux1H(rmap.map { case (k, v) => decodeAddr(k) -> v })
  }

  io.host.r.bits.resp := 0.U
  io.host.r.bits.data := rdata
}
|
|
@ -0,0 +1,254 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.shell
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import vta.util.config._
|
||||
import vta.util.genericbundle._
|
||||
import vta.interface.axi._
|
||||
|
||||
/** VME parameters.
  *
  * These parameters are used on VME interfaces and modules.
  */
case class VMEParams() {
  // Number of read clients multiplexed by the VME.
  val nReadClients: Int = 5
  // Number of write clients; only a single writer is supported for now.
  val nWriteClients: Int = 1

  require(
    nReadClients > 0,
    s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
  require(
    nWriteClients == 1,
    s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
}
|
||||
|
||||
/** VMEBase.
  *
  * Parameterized base class shared by the VME bundles.
  */
abstract class VMEBase(implicit p: Parameters) extends GenericParameterizedBundle(p)
|
||||
|
||||
/** VMECmd.
  *
  * Command used for creating write and read requests to memory: an address
  * plus a length, both sized from the memory interface parameters.
  */
class VMECmd(implicit p: Parameters) extends VMEBase {
  val addrBits = p(ShellKey).memParams.addrBits
  val lenBits = p(ShellKey).memParams.lenBits

  val addr = UInt(addrBits.W)
  val len = UInt(lenBits.W)
}
|
||||
|
||||
/** VMEReadMaster.
  *
  * Interface used by modules inside the core to generate read requests
  * (`cmd`) and receive responses (`data`) from the VME.
  */
class VMEReadMaster(implicit p: Parameters) extends Bundle {
  val dataBits = p(ShellKey).memParams.dataBits

  val cmd = Decoupled(new VMECmd)
  val data = Flipped(Decoupled(UInt(dataBits.W)))

  override def cloneType = new VMEReadMaster().asInstanceOf[this.type]
}
|
||||
|
||||
/** VMEReadClient.
  *
  * Interface used by the VME to receive read requests (`cmd`) and generate
  * responses (`data`) to modules inside the core.
  */
class VMEReadClient(implicit p: Parameters) extends Bundle {
  val dataBits = p(ShellKey).memParams.dataBits

  val cmd = Flipped(Decoupled(new VMECmd))
  val data = Decoupled(UInt(dataBits.W))

  override def cloneType = new VMEReadClient().asInstanceOf[this.type]
}
|
||||
|
||||
/** VMEWriteMaster.
  *
  * Interface used by modules inside the core to generate write requests to
  * the VME: a command, a data stream, and an incoming acknowledge.
  */
class VMEWriteMaster(implicit p: Parameters) extends Bundle {
  val dataBits = p(ShellKey).memParams.dataBits

  val cmd = Decoupled(new VMECmd)
  val data = Decoupled(UInt(dataBits.W))
  val ack = Input(Bool())

  override def cloneType = new VMEWriteMaster().asInstanceOf[this.type]
}
|
||||
|
||||
/** VMEWriteClient.
  *
  * Interface used by the VME to handle write requests from modules inside
  * the core: an incoming command and data stream, and an outgoing acknowledge.
  */
class VMEWriteClient(implicit p: Parameters) extends Bundle {
  val dataBits = p(ShellKey).memParams.dataBits

  val cmd = Flipped(Decoupled(new VMECmd))
  val data = Flipped(Decoupled(UInt(dataBits.W)))
  val ack = Output(Bool())

  override def cloneType = new VMEWriteClient().asInstanceOf[this.type]
}
|
||||
|
||||
/** VMEMaster.
  *
  * Packs nRd VMEReadMaster interfaces and nWr VMEWriteMaster interfaces,
  * with both counts taken from the VME parameters.
  */
class VMEMaster(implicit p: Parameters) extends Bundle {
  val nRd = p(ShellKey).vmeParams.nReadClients
  val nWr = p(ShellKey).vmeParams.nWriteClients

  val rd = Vec(nRd, new VMEReadMaster)
  val wr = Vec(nWr, new VMEWriteMaster)
}
|
||||
|
||||
/** VMEClient.
  *
  * Packs nRd VMEReadClient interfaces and nWr VMEWriteClient interfaces,
  * with both counts taken from the VME parameters.
  */
class VMEClient(implicit p: Parameters) extends Bundle {
  val nRd = p(ShellKey).vmeParams.nReadClients
  val nWr = p(ShellKey).vmeParams.nWriteClients

  val rd = Vec(nRd, new VMEReadClient)
  val wr = Vec(nWr, new VMEWriteClient)
}
|
||||
|
||||
/** VTA Memory Engine (VME).
  *
  * This unit multiplexes the memory controller interface for the Core. Currently,
  * it supports single-writer and multiple-reader mode and it is also based on AXI.
  *
  * Reads: an arbiter selects one read client, its command is latched, issued on
  * the AR channel, and the R channel data is routed back to the chosen client
  * until the last beat is received.
  *
  * Writes: the single write client's command is latched, issued on the AW
  * channel, its data is streamed on the W channel, and the B response is
  * reported back through `ack`.
  */
class VME(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {
    val mem = new AXIMaster(p(ShellKey).memParams)
    val vme = new VMEClient
  })

  val nReadClients = p(ShellKey).vmeParams.nReadClients
  val rd_arb = Module(new Arbiter(new VMECmd, nReadClients))
  // Remember which client won arbitration so read data can be routed back.
  val rd_arb_chosen = RegEnable(rd_arb.io.chosen, rd_arb.io.out.fire())

  for (i <- 0 until nReadClients) { rd_arb.io.in(i) <> io.vme.rd(i).cmd }

  val addrBits = p(ShellKey).memParams.addrBits
  val lenBits = p(ShellKey).memParams.lenBits

  // registers storing read/write cmds
  val rd_len = RegInit(0.U(lenBits.W))
  val wr_len = RegInit(0.U(lenBits.W))
  val rd_addr = RegInit(0.U(addrBits.W))
  val wr_addr = RegInit(0.U(addrBits.W))

  when (rd_arb.io.out.fire()) {
    rd_len := rd_arb.io.out.bits.len
    rd_addr := rd_arb.io.out.bits.addr
  }

  when (io.vme.wr(0).cmd.fire()) {
    wr_len := io.vme.wr(0).cmd.bits.len
    wr_addr := io.vme.wr(0).cmd.bits.addr
  }

  // read state machine: idle -> address -> data -> idle
  val sReadIdle :: sReadAddr :: sReadData :: Nil = Enum(3)
  val rstate = RegInit(sReadIdle)

  switch (rstate) {
    is (sReadIdle) {
      when (rd_arb.io.out.valid) {
        rstate := sReadAddr
      }
    }
    is (sReadAddr) {
      when (io.mem.ar.ready) {
        rstate := sReadData
      }
    }
    is (sReadData) {
      when (io.mem.r.fire() && io.mem.r.bits.last) {
        rstate := sReadIdle
      }
    }
  }

  // write state machine: idle -> address -> data -> response -> idle
  val sWriteIdle :: sWriteAddr :: sWriteData :: sWriteResp :: Nil = Enum(4)
  val wstate = RegInit(sWriteIdle)
  // Number of write beats transferred so far in the current burst.
  val wr_cnt = RegInit(0.U(lenBits.W))

  when (wstate === sWriteIdle) {
    wr_cnt := 0.U
  } .elsewhen (io.mem.w.fire()) {
    wr_cnt := wr_cnt + 1.U
  }

  switch (wstate) {
    is (sWriteIdle) {
      when (io.vme.wr(0).cmd.valid) {
        wstate := sWriteAddr
      }
    }
    is (sWriteAddr) {
      when (io.mem.aw.ready) {
        wstate := sWriteData
      }
    }
    is (sWriteData) {
      // Leave the data phase only when the last beat is actually transferred
      // (valid AND ready); otherwise a stalled client would cause the final
      // beat to be dropped and the burst to never complete.
      when (io.mem.w.fire() && wr_cnt === wr_len) {
        wstate := sWriteResp
      }
    }
    is (sWriteResp) {
      when (io.mem.b.valid) {
        wstate := sWriteIdle
      }
    }
  }

  // rd arb
  rd_arb.io.out.ready := rstate === sReadIdle

  // vme
  for (i <- 0 until nReadClients) {
    // Only the client that won arbitration sees the read data as valid.
    io.vme.rd(i).data.valid := rd_arb_chosen === i.asUInt & io.mem.r.valid
    io.vme.rd(i).data.bits := io.mem.r.bits.data
  }

  io.vme.wr(0).cmd.ready := wstate === sWriteIdle
  io.vme.wr(0).ack := io.mem.b.fire()
  io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready

  // mem
  io.mem.aw.valid := wstate === sWriteAddr
  io.mem.aw.bits.addr := wr_addr
  io.mem.aw.bits.len := wr_len

  io.mem.w.valid := wstate === sWriteData & io.vme.wr(0).data.valid
  io.mem.w.bits.data := io.vme.wr(0).data.bits
  // Compare against the latched length so the last flag always agrees with
  // the length issued on the AW channel.
  io.mem.w.bits.last := wr_cnt === wr_len

  io.mem.b.ready := wstate === sWriteResp

  io.mem.ar.valid := rstate === sReadAddr
  io.mem.ar.bits.addr := rd_addr
  io.mem.ar.bits.len := rd_len

  io.mem.r.ready := rstate === sReadData & io.vme.rd(rd_arb_chosen).data.ready

  // AXI constants - statically defined
  io.mem.setConst()
}
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.shell
|
||||
|
||||
import chisel3._
|
||||
import vta.util.config._
|
||||
import vta.interface.axi._
|
||||
import vta.core._
|
||||
|
||||
/** Shell parameters.
  *
  * Groups the host and memory AXI parameters together with the VCR and VME
  * parameters.
  */
case class ShellParams(
    hostParams: AXIParams,
    memParams: AXIParams,
    vcrParams: VCRParams,
    vmeParams: VMEParams)
|
||||
|
||||
/** Configuration key under which the ShellParams are looked up. */
case object ShellKey extends Field[ShellParams]
|
||||
|
||||
/** VTAShell.
  *
  * The VTAShell is based on a VME, a VCR and a core. Together they form a
  * complete VTA system that can be used for simulation or real hardware.
  */
class VTAShell(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {
    val host = new AXILiteClient(p(ShellKey).hostParams)
    val mem = new AXIMaster(p(ShellKey).memParams)
  })

  val vcr = Module(new VCR)
  val vme = Module(new VME)
  val core = Module(new Core)

  // Core <-> control registers and memory engine.
  core.io.vcr <> vcr.io.vcr
  vme.io.vme <> core.io.vme

  // Shell boundary: host drives the VCR, the VME drives memory.
  vcr.io.host <> io.host
  io.mem <> vme.io.mem
}
|
|
@ -0,0 +1,117 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.shell
|
||||
|
||||
import chisel3._
|
||||
import chisel3.experimental.{RawModule, withClockAndReset}
|
||||
import vta.util.config._
|
||||
import vta.interface.axi._
|
||||
|
||||
/** XilinxShell.
  *
  * This is a wrapper shell mostly used to match the Xilinx naming convention,
  * so that VTA can be packed as an IP for IPI based flows.
  *
  * The wrapped VTAShell is clocked by `ap_clk` and reset by the inverted
  * `ap_rst_n`. Every AXI channel of the shell is wired one-to-one to the
  * matching Xilinx-named port.
  */
class XilinxShell(implicit p: Parameters) extends RawModule {

  val hp = p(ShellKey).hostParams
  val mp = p(ShellKey).memParams

  val ap_clk = IO(Input(Clock()))
  val ap_rst_n = IO(Input(Bool()))
  val m_axi_gmem = IO(new XilinxAXIMaster(mp))
  val s_axi_control = IO(new XilinxAXILiteClient(hp))

  // ap_rst_n is inverted to produce the reset handed to the shell.
  val shell = withClockAndReset (clock = ap_clk, reset = ~ap_rst_n) { Module(new VTAShell) }

  // memory: write address channel
  m_axi_gmem.AWVALID := shell.io.mem.aw.valid
  shell.io.mem.aw.ready := m_axi_gmem.AWREADY
  m_axi_gmem.AWADDR := shell.io.mem.aw.bits.addr
  m_axi_gmem.AWID := shell.io.mem.aw.bits.id
  m_axi_gmem.AWUSER := shell.io.mem.aw.bits.user
  m_axi_gmem.AWLEN := shell.io.mem.aw.bits.len
  m_axi_gmem.AWSIZE := shell.io.mem.aw.bits.size
  m_axi_gmem.AWBURST := shell.io.mem.aw.bits.burst
  m_axi_gmem.AWLOCK := shell.io.mem.aw.bits.lock
  m_axi_gmem.AWCACHE := shell.io.mem.aw.bits.cache
  m_axi_gmem.AWPROT := shell.io.mem.aw.bits.prot
  m_axi_gmem.AWQOS := shell.io.mem.aw.bits.qos
  m_axi_gmem.AWREGION := shell.io.mem.aw.bits.region

  // memory: write data channel
  m_axi_gmem.WVALID := shell.io.mem.w.valid
  shell.io.mem.w.ready := m_axi_gmem.WREADY
  m_axi_gmem.WDATA := shell.io.mem.w.bits.data
  m_axi_gmem.WSTRB := shell.io.mem.w.bits.strb
  m_axi_gmem.WLAST := shell.io.mem.w.bits.last
  m_axi_gmem.WID := shell.io.mem.w.bits.id
  m_axi_gmem.WUSER := shell.io.mem.w.bits.user

  // memory: write response channel
  shell.io.mem.b.valid := m_axi_gmem.BVALID
  // BREADY must reflect the shell's readiness to accept the response, not
  // the looped-back BVALID.
  m_axi_gmem.BREADY := shell.io.mem.b.ready
  shell.io.mem.b.bits.resp := m_axi_gmem.BRESP
  shell.io.mem.b.bits.id := m_axi_gmem.BID
  shell.io.mem.b.bits.user := m_axi_gmem.BUSER

  // memory: read address channel
  m_axi_gmem.ARVALID := shell.io.mem.ar.valid
  shell.io.mem.ar.ready := m_axi_gmem.ARREADY
  m_axi_gmem.ARADDR := shell.io.mem.ar.bits.addr
  m_axi_gmem.ARID := shell.io.mem.ar.bits.id
  m_axi_gmem.ARUSER := shell.io.mem.ar.bits.user
  m_axi_gmem.ARLEN := shell.io.mem.ar.bits.len
  m_axi_gmem.ARSIZE := shell.io.mem.ar.bits.size
  m_axi_gmem.ARBURST := shell.io.mem.ar.bits.burst
  m_axi_gmem.ARLOCK := shell.io.mem.ar.bits.lock
  m_axi_gmem.ARCACHE := shell.io.mem.ar.bits.cache
  m_axi_gmem.ARPROT := shell.io.mem.ar.bits.prot
  m_axi_gmem.ARQOS := shell.io.mem.ar.bits.qos
  m_axi_gmem.ARREGION := shell.io.mem.ar.bits.region

  // memory: read data channel
  shell.io.mem.r.valid := m_axi_gmem.RVALID
  m_axi_gmem.RREADY := shell.io.mem.r.ready
  shell.io.mem.r.bits.data := m_axi_gmem.RDATA
  shell.io.mem.r.bits.resp := m_axi_gmem.RRESP
  shell.io.mem.r.bits.last := m_axi_gmem.RLAST
  shell.io.mem.r.bits.id := m_axi_gmem.RID
  shell.io.mem.r.bits.user := m_axi_gmem.RUSER

  // host: write address channel
  shell.io.host.aw.valid := s_axi_control.AWVALID
  s_axi_control.AWREADY := shell.io.host.aw.ready
  shell.io.host.aw.bits.addr := s_axi_control.AWADDR

  // host: write data channel
  shell.io.host.w.valid := s_axi_control.WVALID
  s_axi_control.WREADY := shell.io.host.w.ready
  shell.io.host.w.bits.data := s_axi_control.WDATA
  shell.io.host.w.bits.strb := s_axi_control.WSTRB

  // host: write response channel
  s_axi_control.BVALID := shell.io.host.b.valid
  shell.io.host.b.ready := s_axi_control.BREADY
  s_axi_control.BRESP := shell.io.host.b.bits.resp

  // host: read address channel
  shell.io.host.ar.valid := s_axi_control.ARVALID
  s_axi_control.ARREADY := shell.io.host.ar.ready
  shell.io.host.ar.bits.addr := s_axi_control.ARADDR

  // host: read data channel
  s_axi_control.RVALID := shell.io.host.r.valid
  shell.io.host.r.ready := s_axi_control.RREADY
  s_axi_control.RDATA := shell.io.host.r.bits.data
  s_axi_control.RRESP := shell.io.host.r.bits.resp
}
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.test
|
||||
|
||||
import chisel3._
|
||||
import vta.util.config._
|
||||
import vta.shell._
|
||||
|
||||
/** Test.
  *
  * Generates a testbench for simulation by connecting a SimShell to a
  * VTAShell. The module itself exposes no ports.
  */
class Test(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {})

  val sim_shell = Module(new SimShell)
  val vta_shell = Module(new VTAShell)

  // The simulated host drives the VTA host port; VTA memory traffic goes to
  // the simulated memory.
  vta_shell.io.host <> sim_shell.io.host
  sim_shell.io.mem <> vta_shell.io.mem
}
|
|
@ -0,0 +1,104 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.util.config
|
||||
|
||||
// taken from https://github.com/freechipsproject/rocket-chip
|
||||
|
||||
/** Typed configuration key with an optional default value.
  *
  * The primary constructor is private; subclasses pick one of the two public
  * constructors to declare a key without or with a default.
  */
abstract class Field[T] private (val default: Option[T]) {
  /** Key with no default value. */
  def this() = this(None)

  /** Key that falls back to `default` when no configuration defines it. */
  def this(default: T) = this(Some(default))
}
|
||||
|
||||
/** Read-only view over a set of parameters, queried by Field key. */
abstract class View {
  /** Looks up `pname` using this view as the site; fails if undefined. */
  final def apply[T](pname: Field[T]): T = apply(pname, this)

  /** Looks up `pname` from `site`; fails the requirement if it is undefined. */
  final def apply[T](pname: Field[T], site: View): T = {
    val found = find(pname, site)
    require(found.isDefined, s"Key ${pname} is not defined in Parameters")
    found.get
  }

  /** Like `apply`, but returns None instead of failing. */
  final def lift[T](pname: Field[T]): Option[T] = lift(pname, this)

  /** Like `apply` with an explicit site, but returns None instead of failing. */
  final def lift[T](pname: Field[T], site: View): Option[T] = {
    val found = find(pname, site)
    found.map(value => value.asInstanceOf[T])
  }

  /** Resolution hook implemented by concrete views. */
  protected[config] def find[T](pname: Field[T], site: View): Option[T]
}
|
||||
|
||||
/** Composable set of parameters.
  *
  * Parameter sets are chained with `++`; lookups try the left-hand side
  * first and fall through to the right-hand side.
  */
abstract class Parameters extends View {
  /** Chains `x` behind this parameter set. */
  final def ++ (x: Parameters): Parameters = new ChainParameters(this, x)

  /** Returns a new set in which `f` takes precedence over this one. */
  final def alter(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = {
    val overrides = Parameters(f)
    overrides ++ this
  }

  /** Like `alter`, for a partial function that ignores site/here/up. */
  final def alterPartial(f: PartialFunction[Any,Any]): Parameters = {
    val overrides = Parameters((_,_,_) => f)
    overrides ++ this
  }

  /** Like `alter`, with the overrides given as a plain map. */
  final def alterMap(m: Map[Any,Any]): Parameters = {
    val overrides = new MapParameters(m)
    overrides ++ this
  }

  /** Resolves `pname`, delegating to `tail` when this set does not define it. */
  protected[config] def chain[T](site: View, tail: View, pname: Field[T]): Option[T]

  /** Top-level lookup: the chain ends in a TerminalView (field defaults). */
  protected[config] def find[T](pname: Field[T], site: View) =
    chain(site, new TerminalView, pname)
}
|
||||
|
||||
/** Factory methods for Parameters. */
object Parameters {
  /** A parameter set that defines nothing and always defers to its tail. */
  def empty: Parameters = new EmptyParameters

  /** Wraps a (site, here, up) => partial function as a parameter set. */
  def apply(f: (View, View, View) => PartialFunction[Any,Any]): Parameters =
    new PartialParameters(f)
}
|
||||
|
||||
/** Named, subclassable wrapper around a parameter set.
  *
  * All lookups are delegated to the wrapped parameters `p`.
  */
class Config(p: Parameters) extends Parameters {
  /** Builds a Config directly from a (site, here, up) => partial function. */
  def this(f: (View, View, View) => PartialFunction[Any,Any]) = this(Parameters(f))

  protected[config] def chain[T](site: View, tail: View, pname: Field[T]) =
    p.chain(site, tail, pname)

  /** Prints as the simple name of the concrete class. */
  override def toString = this.getClass.getSimpleName

  def toInstance = this
}
|
||||
|
||||
// Internal implementation:
|
||||
|
||||
/** End of every lookup chain: answers with the field's own default. */
private class TerminalView extends View {
  def find[T](pname: Field[T], site: View): Option[T] = {
    pname.default
  }
}
|
||||
|
||||
/** View that resolves through `head`, falling back to `tail`. */
private class ChainView(head: Parameters, tail: View) extends View {
  def find[T](pname: Field[T], site: View) = {
    head.chain(site, tail, pname)
  }
}
|
||||
|
||||
/** Parameter set produced by `x ++ y`: `x` is consulted first, then `y`. */
private class ChainParameters(x: Parameters, y: Parameters) extends Parameters {
  def chain[T](site: View, tail: View, pname: Field[T]) = {
    // `y` becomes the head of the tail view that `x` falls back to.
    val fallback = new ChainView(y, tail)
    x.chain(site, fallback, pname)
  }
}
|
||||
|
||||
/** Parameter set without any bindings: every lookup goes straight to `tail`. */
private class EmptyParameters extends Parameters {
  def chain[T](site: View, tail: View, pname: Field[T]) = {
    tail.find(pname, site)
  }
}
|
||||
|
||||
/** Parameters backed by a function that, given three views, yields a partial function of bindings. */
private class PartialParameters(f: (View, View, View) => PartialFunction[Any,Any]) extends Parameters {
  /** Uses the binding for `pname` when the partial function defines one, otherwise defers to `tail`. */
  protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = {
    // The function is handed the lookup site, this parameter set, and the fallback view.
    val bindings = f(site, this, tail)
    if (!bindings.isDefinedAt(pname)) {
      tail.find(pname, site)
    } else {
      Some(bindings(pname).asInstanceOf[T])
    }
  }
}
|
||||
|
||||
/** Parameters backed by a plain map from keys to values. */
private class MapParameters(map: Map[Any, Any]) extends Parameters {
  /** Returns the map entry for `pname` (cast to `T`) when present, otherwise defers to `tail`. */
  protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = {
    map.get(pname) match {
      case Some(value) => Some(value.asInstanceOf[T])
      case None => tail.find(pname, site)
    }
  }
}
|
|
@ -0,0 +1,40 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta.util.genericbundle
|
||||
|
||||
// taken from https://github.com/freechipsproject/rocket-chip
|
||||
|
||||
import chisel3._
|
||||
|
||||
/** Bundle carrying a single `params` object, with a reflection-based `cloneType`.
  *
  * `cloneType` invokes the first public constructor of the runtime class with
  * `params` as its only argument.
  */
abstract class GenericParameterizedBundle[+T <: Object](val params: T) extends Bundle {
  override def cloneType = {
    val cls = this.getClass
    try {
      val ctor = cls.getConstructors.head
      ctor.newInstance(params).asInstanceOf[this.type]
    } catch {
      // Raised by the reflective call when the constructor does not accept exactly `params`.
      case e: java.lang.IllegalArgumentException =>
        val message =
          s"Unable to use GenericParameterizedBundle.cloneType on $cls, probably because $cls" +
          "() takes more than one argument. Consider overriding " +
          s"cloneType() on $cls"
        throw new Exception(message, e)
    }
  }
}
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package vta
|
||||
|
||||
import chisel3._
|
||||
import vta.util.config._
|
||||
import vta.shell._
|
||||
import vta.core._
|
||||
import vta.test._
|
||||
|
||||
/** VTA.
|
||||
*
|
||||
* This file contains all the configurations supported by VTA.
|
||||
* These configurations are built in a mix/match form based on core
|
||||
* and shell configurations.
|
||||
*/
|
||||
|
||||
/** Configuration built by chaining `CoreConfig` in front of `PynqConfig`. */
class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig)
|
||||
/** Configuration built by chaining `CoreConfig` in front of `F1Config`. */
class DefaultF1Config extends Config(new CoreConfig ++ new F1Config)
|
||||
|
||||
/** Entry point that runs the Chisel driver on `XilinxShell` under `DefaultPynqConfig`. */
object DefaultPynqConfig extends App {
  // Implicit parameters in scope while the shell is constructed.
  implicit val p: Parameters = new DefaultPynqConfig

  // Command-line arguments are passed straight through to the Chisel driver.
  chisel3.Driver.execute(
    args,
    () => new XilinxShell
  )
}
|
||||
|
||||
/** Entry point that runs the Chisel driver on `XilinxShell` under `DefaultF1Config`. */
object DefaultF1Config extends App {
  // Implicit parameters in scope while the shell is constructed.
  implicit val p: Parameters = new DefaultF1Config

  // Command-line arguments are passed straight through to the Chisel driver.
  chisel3.Driver.execute(
    args,
    () => new XilinxShell
  )
}
|
||||
|
||||
/** Entry point that runs the Chisel driver on the `Test` module under `DefaultF1Config`. */
object TestDefaultF1Config extends App {
  // Implicit parameters in scope while the module is constructed.
  implicit val p: Parameters = new DefaultF1Config

  // Command-line arguments are passed straight through to the Chisel driver.
  chisel3.Driver.execute(
    args,
    () => new Test
  )
}
|
|
@ -70,8 +70,18 @@ void VTADPIInit(VTAContextHandle handle,
|
|||
_mem_dpi = mem_dpi;
|
||||
}
|
||||
|
||||
|
||||
// Override Verilator finish definition
|
||||
// VL_USER_FINISH needs to be defined when compiling Verilator code
|
||||
/*!
 * \brief User-provided replacement for Verilator's finish handler.
 *
 * Marks the simulation as finished and prints a short notice.
 * The source-location arguments are not used.
 */
void vl_finish(const char* filename,
               int linenum,
               const char* hier) {
  // Raise the finish flag.
  Verilated::gotFinish(true);
  VL_PRINTF("[TSIM] exiting simulation\n");
}
|
||||
|
||||
int VTADPISim(uint64_t max_cycles) {
|
||||
uint64_t trace_count = 0;
|
||||
Verilated::flushCall();
|
||||
Verilated::gotFinish(false);
|
||||
|
||||
#if VM_TRACE
|
||||
uint64_t start = 0;
|
||||
|
|
|
@ -53,7 +53,11 @@ extern "C" {
|
|||
typedef void * VTADeviceHandle;
|
||||
|
||||
/*! \brief physical address */
|
||||
#ifdef USE_TSIM
|
||||
typedef uint64_t vta_phy_addr_t;
|
||||
#else
|
||||
typedef uint32_t vta_phy_addr_t;
|
||||
#endif
|
||||
|
||||
/*!
|
||||
* \brief Allocate a device resource handle
|
||||
|
@ -76,10 +80,22 @@ void VTADeviceFree(VTADeviceHandle handle);
|
|||
*
|
||||
* \return 0 if running is successful, 1 if timeout.
|
||||
*/
|
||||
#ifdef USE_TSIM
|
||||
int VTADeviceRun(VTADeviceHandle device,
|
||||
vta_phy_addr_t insn_phy_addr,
|
||||
vta_phy_addr_t uop_phy_addr,
|
||||
vta_phy_addr_t inp_phy_addr,
|
||||
vta_phy_addr_t wgt_phy_addr,
|
||||
vta_phy_addr_t acc_phy_addr,
|
||||
vta_phy_addr_t out_phy_addr,
|
||||
uint32_t insn_count,
|
||||
uint32_t wait_cycles);
|
||||
#else
|
||||
int VTADeviceRun(VTADeviceHandle device,
|
||||
vta_phy_addr_t insn_phy_addr,
|
||||
uint32_t insn_count,
|
||||
uint32_t wait_cycles);
|
||||
#endif
|
||||
|
||||
/*!
|
||||
* \brief Allocates physically contiguous region in memory (limited by MAX_XFER).
|
||||
|
|
|
@ -239,7 +239,7 @@ class Environment(object):
|
|||
"""The target host"""
|
||||
if self.TARGET == "pynq":
|
||||
return "llvm -target=armv7-none-linux-gnueabihf"
|
||||
if self.TARGET == "sim":
|
||||
if self.TARGET == "sim" or self.TARGET == "tsim":
|
||||
return "llvm"
|
||||
raise ValueError("Unknown target %s" % self.TARGET)
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
"""Utilities to start simulator."""
|
||||
import ctypes
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import tvm
|
||||
from ..libinfo import find_libvta
|
||||
|
||||
|
@ -55,5 +57,22 @@ def stats():
|
|||
x = tvm.get_global_func("vta.simulator.profiler_status")()
|
||||
return json.loads(x)
|
||||
|
||||
def tsim_init(hw_lib):
    """Load the TSIM hardware shared library and register it with the runtime.

    The library is looked up in the ``build`` directory three levels above
    the directory containing this file.

    Parameters
    ----------
    hw_lib : str
        Name of the hardware shared library. When the name does not already
        end in ``dylib`` or ``so``, the platform extension is appended
        (``.dylib`` on macOS, ``.so`` otherwise).
    """
    here = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    if not hw_lib.endswith(("dylib", "so")):
        suffix = ".dylib" if sys.platform == "darwin" else ".so"
        hw_lib = hw_lib + suffix
    lib_path = os.path.join(here, "..", "..", "..", "build", hw_lib)
    # Fetch the init hook before loading the module, then hand the module over.
    init_func = tvm.get_global_func("tvm.vta.tsim.init")
    hw_module = tvm.module.load(lib_path, "vta-tsim")
    init_func(hw_module)
|
||||
|
||||
|
||||
LIBS = _load_lib()
|
||||
|
|
|
@ -31,7 +31,7 @@ def run(run_func):
|
|||
"""
|
||||
env = get_env()
|
||||
|
||||
if env.TARGET == "sim":
|
||||
if env.TARGET in ["sim", "tsim"]:
|
||||
|
||||
# Talk to local RPC if necessary to debug RPC server.
|
||||
# Compile vta on your host with make at the root.
|
||||
|
@ -48,6 +48,7 @@ def run(run_func):
|
|||
# Make sure simulation library exists
|
||||
# If this fails, build vta on host (make)
|
||||
# with TARGET="sim" in the json.config file.
|
||||
if env.TARGET == "sim":
|
||||
assert simulator.enabled()
|
||||
run_func(env, rpc.LocalSession())
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ struct DataBuffer {
|
|||
return data_;
|
||||
}
|
||||
/*! \return Physical address of the data. */
|
||||
uint32_t phy_addr() const {
|
||||
vta_phy_addr_t phy_addr() const {
|
||||
return phy_addr_;
|
||||
}
|
||||
/*!
|
||||
|
@ -113,7 +113,7 @@ struct DataBuffer {
|
|||
/*! \brief The internal data. */
|
||||
void* data_;
|
||||
/*! \brief The physical address of the buffer, excluding header. */
|
||||
uint32_t phy_addr_;
|
||||
vta_phy_addr_t phy_addr_;
|
||||
};
|
||||
|
||||
/*!
|
||||
|
@ -302,7 +302,7 @@ class BaseQueue {
|
|||
return dram_buffer_;
|
||||
}
|
||||
/*! \return Physical address of DRAM. */
|
||||
uint32_t dram_phy_addr() const {
|
||||
vta_phy_addr_t dram_phy_addr() const {
|
||||
return dram_phy_addr_;
|
||||
}
|
||||
/*! \return Whether there is pending information. */
|
||||
|
@ -367,7 +367,7 @@ class BaseQueue {
|
|||
// The buffer in DRAM
|
||||
char* dram_buffer_{nullptr};
|
||||
// Physical address of the buffer
|
||||
uint32_t dram_phy_addr_;
|
||||
vta_phy_addr_t dram_phy_addr_;
|
||||
};
|
||||
|
||||
/*!
|
||||
|
@ -424,7 +424,11 @@ class UopQueue : public BaseQueue {
|
|||
CHECK((dram_end_ - dram_begin_) == (sram_end_ - sram_begin_));
|
||||
insn->memory_type = VTA_MEM_ID_UOP;
|
||||
insn->sram_base = sram_begin_;
|
||||
#ifdef USE_TSIM
|
||||
insn->dram_base = (uint32_t) dram_phy_addr_ + dram_begin_*kElemBytes;
|
||||
#else
|
||||
insn->dram_base = dram_phy_addr_ / kElemBytes + dram_begin_;
|
||||
#endif
|
||||
insn->y_size = 1;
|
||||
insn->x_size = (dram_end_ - dram_begin_);
|
||||
insn->x_stride = (dram_end_ - dram_begin_);
|
||||
|
@ -958,7 +962,11 @@ class CommandQueue {
|
|||
insn->memory_type = dst_memory_type;
|
||||
insn->sram_base = dst_sram_index;
|
||||
DataBuffer* src = DataBuffer::FromHandle(src_dram_addr);
|
||||
#ifdef USE_TSIM
|
||||
insn->dram_base = (uint32_t) src->phy_addr() + src_elem_offset*GetElemBytes(dst_memory_type);
|
||||
#else
|
||||
insn->dram_base = src->phy_addr() / GetElemBytes(dst_memory_type) + src_elem_offset;
|
||||
#endif
|
||||
insn->y_size = y_size;
|
||||
insn->x_size = x_size;
|
||||
insn->x_stride = x_stride;
|
||||
|
@ -981,7 +989,11 @@ class CommandQueue {
|
|||
insn->memory_type = src_memory_type;
|
||||
insn->sram_base = src_sram_index;
|
||||
DataBuffer* dst = DataBuffer::FromHandle(dst_dram_addr);
|
||||
#ifdef USE_TSIM
|
||||
insn->dram_base = (uint32_t) dst->phy_addr() + dst_elem_offset*GetElemBytes(src_memory_type);
|
||||
#else
|
||||
insn->dram_base = dst->phy_addr() / GetElemBytes(src_memory_type) + dst_elem_offset;
|
||||
#endif
|
||||
insn->y_size = y_size;
|
||||
insn->x_size = x_size;
|
||||
insn->x_stride = x_stride;
|
||||
|
@ -1046,11 +1058,24 @@ class CommandQueue {
|
|||
|
||||
// Make sure that we don't exceed contiguous physical memory limits
|
||||
CHECK(insn_queue_.count() * sizeof(VTAGenericInsn) < VTA_MAX_XFER);
|
||||
#ifdef USE_TSIM
|
||||
int timeout = VTADeviceRun(
|
||||
device_,
|
||||
insn_queue_.dram_phy_addr(),
|
||||
uop_queue_.dram_phy_addr(),
|
||||
inp_phy_addr_,
|
||||
wgt_phy_addr_,
|
||||
acc_phy_addr_,
|
||||
out_phy_addr_,
|
||||
insn_queue_.count(),
|
||||
wait_cycles);
|
||||
#else
|
||||
int timeout = VTADeviceRun(
|
||||
device_,
|
||||
insn_queue_.dram_phy_addr(),
|
||||
insn_queue_.count(),
|
||||
wait_cycles);
|
||||
#endif
|
||||
CHECK_EQ(timeout, 0);
|
||||
// Reset buffers
|
||||
uop_queue_.Reset();
|
||||
|
@ -1125,6 +1150,18 @@ class CommandQueue {
|
|||
ThreadLocal().reset();
|
||||
}
|
||||
|
||||
#ifdef USE_TSIM
|
||||
void SetBufPhyAddr(uint32_t type, vta_phy_addr_t addr) {
|
||||
switch (type) {
|
||||
case VTA_MEM_ID_INP: inp_phy_addr_ = addr;
|
||||
case VTA_MEM_ID_WGT: wgt_phy_addr_ = addr;
|
||||
case VTA_MEM_ID_ACC: acc_phy_addr_ = addr;
|
||||
case VTA_MEM_ID_OUT: out_phy_addr_ = addr;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
private:
|
||||
// Push GEMM uop to the command buffer
|
||||
void PushGEMMOp(UopKernel* kernel) {
|
||||
|
@ -1229,6 +1266,16 @@ class CommandQueue {
|
|||
InsnQueue<VTA_MAX_XFER, true, true> insn_queue_;
|
||||
// Device handle
|
||||
VTADeviceHandle device_{nullptr};
|
||||
#ifdef USE_TSIM
|
||||
// Input phy addr
|
||||
vta_phy_addr_t inp_phy_addr_{0};
|
||||
// Weight phy addr
|
||||
vta_phy_addr_t wgt_phy_addr_{0};
|
||||
// Accumulator phy addr
|
||||
vta_phy_addr_t acc_phy_addr_{0};
|
||||
// Output phy addr
|
||||
vta_phy_addr_t out_phy_addr_{0};
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace vta
|
||||
|
@ -1317,6 +1364,10 @@ void VTALoadBuffer2D(VTACommandHandle cmd,
|
|||
uint32_t y_pad_after,
|
||||
uint32_t dst_sram_index,
|
||||
uint32_t dst_memory_type) {
|
||||
#ifdef USE_TSIM
|
||||
vta::DataBuffer* src = vta::DataBuffer::FromHandle(src_dram_addr);
|
||||
static_cast<vta::CommandQueue*>(cmd)->SetBufPhyAddr(dst_memory_type, src->phy_addr());
|
||||
#endif
|
||||
static_cast<vta::CommandQueue*>(cmd)->
|
||||
LoadBuffer2D(src_dram_addr, src_elem_offset,
|
||||
x_size, y_size, x_stride,
|
||||
|
@ -1333,6 +1384,10 @@ void VTAStoreBuffer2D(VTACommandHandle cmd,
|
|||
uint32_t x_size,
|
||||
uint32_t y_size,
|
||||
uint32_t x_stride) {
|
||||
#ifdef USE_TSIM
|
||||
vta::DataBuffer* dst = vta::DataBuffer::FromHandle(dst_dram_addr);
|
||||
static_cast<vta::CommandQueue*>(cmd)->SetBufPhyAddr(src_memory_type, dst->phy_addr());
|
||||
#endif
|
||||
static_cast<vta::CommandQueue*>(cmd)->
|
||||
StoreBuffer2D(src_sram_index, src_memory_type,
|
||||
dst_dram_addr, dst_elem_offset,
|
||||
|
|
|
@ -0,0 +1,179 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
#include <vta/driver.h>
|
||||
#include <tvm/runtime/module.h>
|
||||
#include <tvm/runtime/registry.h>
|
||||
#include <vta/dpi/module.h>
|
||||
|
||||
namespace vta {
|
||||
namespace tsim {
|
||||
|
||||
using vta::dpi::DPIModuleNode;
|
||||
using tvm::runtime::Module;
|
||||
|
||||
class DPILoader {
|
||||
public:
|
||||
void Init(Module module) {
|
||||
mod_ = module;
|
||||
}
|
||||
|
||||
DPIModuleNode* Get() {
|
||||
return static_cast<DPIModuleNode*>(mod_.operator->());
|
||||
}
|
||||
|
||||
static DPILoader* Global() {
|
||||
static DPILoader inst;
|
||||
return &inst;
|
||||
}
|
||||
|
||||
Module mod_;
|
||||
};
|
||||
|
||||
class Device {
|
||||
public:
|
||||
Device() {
|
||||
dpi_ = DPILoader::Global();
|
||||
}
|
||||
|
||||
int Run(vta_phy_addr_t insn_phy_addr,
|
||||
vta_phy_addr_t uop_phy_addr,
|
||||
vta_phy_addr_t inp_phy_addr,
|
||||
vta_phy_addr_t wgt_phy_addr,
|
||||
vta_phy_addr_t acc_phy_addr,
|
||||
vta_phy_addr_t out_phy_addr,
|
||||
uint32_t insn_count,
|
||||
uint32_t wait_cycles) {
|
||||
this->Init();
|
||||
this->Launch(insn_phy_addr,
|
||||
uop_phy_addr,
|
||||
inp_phy_addr,
|
||||
wgt_phy_addr,
|
||||
acc_phy_addr,
|
||||
out_phy_addr,
|
||||
insn_count,
|
||||
wait_cycles);
|
||||
this->WaitForCompletion(wait_cycles);
|
||||
dev_->Finish();
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
void Init() {
|
||||
dev_ = dpi_->Get();
|
||||
}
|
||||
|
||||
void Launch(vta_phy_addr_t insn_phy_addr,
|
||||
vta_phy_addr_t uop_phy_addr,
|
||||
vta_phy_addr_t inp_phy_addr,
|
||||
vta_phy_addr_t wgt_phy_addr,
|
||||
vta_phy_addr_t acc_phy_addr,
|
||||
vta_phy_addr_t out_phy_addr,
|
||||
uint32_t insn_count,
|
||||
uint32_t wait_cycles) {
|
||||
// launch simulation thread
|
||||
dev_->Launch(wait_cycles);
|
||||
dev_->WriteReg(0x10, insn_count);
|
||||
dev_->WriteReg(0x14, insn_phy_addr);
|
||||
dev_->WriteReg(0x18, insn_phy_addr >> 32);
|
||||
dev_->WriteReg(0x1c, 0);
|
||||
dev_->WriteReg(0x20, uop_phy_addr >> 32);
|
||||
dev_->WriteReg(0x24, 0);
|
||||
dev_->WriteReg(0x28, inp_phy_addr >> 32);
|
||||
dev_->WriteReg(0x2c, 0);
|
||||
dev_->WriteReg(0x30, wgt_phy_addr >> 32);
|
||||
dev_->WriteReg(0x34, 0);
|
||||
dev_->WriteReg(0x38, acc_phy_addr >> 32);
|
||||
dev_->WriteReg(0x3c, 0);
|
||||
dev_->WriteReg(0x40, out_phy_addr >> 32);
|
||||
// start
|
||||
dev_->WriteReg(0x00, 0x1);
|
||||
}
|
||||
|
||||
void WaitForCompletion(uint32_t wait_cycles) {
|
||||
uint32_t i, val;
|
||||
for (i = 0; i < wait_cycles; i++) {
|
||||
val = dev_->ReadReg(0x00);
|
||||
val &= 0x2;
|
||||
if (val == 0x2) break; // finish
|
||||
}
|
||||
}
|
||||
|
||||
DPILoader* dpi_;
|
||||
DPIModuleNode* dev_;
|
||||
};
|
||||
|
||||
using tvm::runtime::TVMRetValue;
|
||||
using tvm::runtime::TVMArgs;
|
||||
|
||||
// Global function "tvm.vta.tsim.init": stores the module passed as the
// first argument in the global DPI loader.
TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module hw_module = args[0];
    DPILoader* loader = DPILoader::Global();
    loader->Init(hw_module);
  });
|
||||
|
||||
} // namespace tsim
|
||||
} // namespace vta
|
||||
|
||||
/*!
 * \brief Allocate a buffer of \p size bytes with malloc.
 * \param size Number of bytes to allocate.
 * \param cached Not used by this implementation.
 * \return The pointer returned by malloc.
 */
void* VTAMemAlloc(size_t size, int cached) {
  return malloc(size);
}
|
||||
|
||||
/*!
 * \brief Release a buffer by passing it to free.
 * \param buf Pointer to release.
 */
void VTAMemFree(void* buf) {
  free(buf);
}
|
||||
|
||||
/*!
 * \brief Return the numeric value of the host pointer as the buffer's address.
 * \param buf Host pointer to the buffer.
 * \return The pointer value converted to an integer.
 */
vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
  const uint64_t addr = reinterpret_cast<uint64_t>(buf);
  return addr;
}
|
||||
|
||||
/*! \brief Cache flush hook; this implementation does nothing. */
void VTAFlushCache(vta_phy_addr_t buf, int size) {
  // Intentionally empty.
}
|
||||
|
||||
/*! \brief Cache invalidate hook; this implementation does nothing. */
void VTAInvalidateCache(vta_phy_addr_t buf, int size) {
  // Intentionally empty.
}
|
||||
|
||||
/*!
 * \brief Create a new TSIM device.
 * \return Opaque handle to a heap-allocated vta::tsim::Device.
 */
VTADeviceHandle VTADeviceAlloc() {
  VTADeviceHandle handle = new vta::tsim::Device();
  return handle;
}
|
||||
|
||||
/*!
 * \brief Destroy a TSIM device.
 * \param handle Handle that is cast back to vta::tsim::Device and deleted.
 */
void VTADeviceFree(VTADeviceHandle handle) {
  auto* device = static_cast<vta::tsim::Device*>(handle);
  delete device;
}
|
||||
|
||||
/*!
 * \brief Run a batch of instructions on the TSIM device behind \p handle.
 *
 * All arguments are forwarded unchanged to vta::tsim::Device::Run.
 *
 * \return Whatever Device::Run returns.
 */
int VTADeviceRun(VTADeviceHandle handle,
                 vta_phy_addr_t insn_phy_addr,
                 vta_phy_addr_t uop_phy_addr,
                 vta_phy_addr_t inp_phy_addr,
                 vta_phy_addr_t wgt_phy_addr,
                 vta_phy_addr_t acc_phy_addr,
                 vta_phy_addr_t out_phy_addr,
                 uint32_t insn_count,
                 uint32_t wait_cycles) {
  auto* device = static_cast<vta::tsim::Device*>(handle);
  return device->Run(insn_phy_addr, uop_phy_addr, inp_phy_addr, wgt_phy_addr,
                     acc_phy_addr, out_phy_addr, insn_count, wait_cycles);
}
|
|
@ -68,6 +68,10 @@ def test_save_load_out():
|
|||
y_np = x_np.astype(y.dtype)
|
||||
x_nd = tvm.nd.array(x_np, ctx)
|
||||
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
|
||||
|
||||
if env.TARGET == "tsim":
|
||||
simulator.tsim_init("libvta_hw")
|
||||
|
||||
f(x_nd, y_nd)
|
||||
np.testing.assert_equal(y_np, y_nd.asnumpy())
|
||||
|
||||
|
@ -126,6 +130,10 @@ def test_padded_load():
|
|||
:] = x_np
|
||||
x_nd = tvm.nd.array(x_np, ctx)
|
||||
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
|
||||
|
||||
if env.TARGET == "tsim":
|
||||
simulator.tsim_init("libvta_hw")
|
||||
|
||||
f(x_nd, y_nd)
|
||||
np.testing.assert_equal(y_np, y_nd.asnumpy())
|
||||
|
||||
|
@ -197,6 +205,9 @@ def test_gemm():
|
|||
y_np = np.right_shift(y_np, 8)
|
||||
y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
|
||||
|
||||
if env.TARGET == "tsim":
|
||||
simulator.tsim_init("libvta_hw")
|
||||
|
||||
if env.TARGET == "sim":
|
||||
simulator.clear_stats()
|
||||
f(x_nd, w_nd, y_nd)
|
||||
|
@ -351,6 +362,10 @@ def test_alu():
|
|||
a_nd = tvm.nd.array(a_np, ctx)
|
||||
res_nd = tvm.nd.array(
|
||||
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
|
||||
|
||||
if env.TARGET == "tsim":
|
||||
simulator.tsim_init("libvta_hw")
|
||||
|
||||
if use_imm:
|
||||
f(a_nd, res_nd)
|
||||
else:
|
||||
|
@ -420,6 +435,10 @@ def test_relu():
|
|||
a_nd = tvm.nd.array(a_np, ctx)
|
||||
res_nd = tvm.nd.array(
|
||||
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
|
||||
|
||||
if env.TARGET == "tsim":
|
||||
simulator.tsim_init("libvta_hw")
|
||||
|
||||
f(a_nd, res_nd)
|
||||
np.testing.assert_equal(res_np, res_nd.asnumpy())
|
||||
|
||||
|
@ -479,6 +498,10 @@ def test_shift_and_scale():
|
|||
a_nd = tvm.nd.array(a_np, ctx)
|
||||
res_nd = tvm.nd.array(
|
||||
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
|
||||
|
||||
if env.TARGET == "tsim":
|
||||
simulator.tsim_init("libvta_hw")
|
||||
|
||||
f(a_nd, res_nd)
|
||||
np.testing.assert_equal(res_np, res_nd.asnumpy())
|
||||
|
||||
|
@ -503,11 +526,12 @@ if __name__ == "__main__":
|
|||
print("Load/store test")
|
||||
test_save_load_out()
|
||||
print("Padded load test")
|
||||
#test_padded_load()
|
||||
test_padded_load()
|
||||
print("GEMM test")
|
||||
test_gemm()
|
||||
test_alu()
|
||||
print("ALU test")
|
||||
test_alu()
|
||||
print("Relu test")
|
||||
test_relu()
|
||||
print("Shift and scale")
|
||||
test_shift_and_scale()
|
||||
|
|
Загрузка…
Ссылка в новой задаче