README.md

@@ -1,73 +1,90 @@
## The Edge Machine Learning library

This repository provides code for machine learning algorithms for edge devices
developed at [Microsoft Research
India](https://www.microsoft.com/en-us/research/project/resource-efficient-ml-for-the-edge-and-endpoint-iot-devices/).

Machine learning models for edge devices need to have a small footprint in
terms of storage, prediction latency, and energy. One instance where such
models are desirable is resource-scarce devices and sensors in the Internet
of Things (IoT) setting. Making real-time predictions locally on IoT devices,
without connecting to the cloud, requires models that fit in a few kilobytes.

### Contents
Algorithms that shine in this setting in terms of both model size and compute, namely:

- **Bonsai**: Strong and shallow non-linear tree-based classifier.
- **ProtoNN**: **Proto**type-based k-nearest neighbors (k**NN**) classifier.
- **EMI-RNN**: Training routine to recover the critical signature from time-series data, for faster and more accurate RNN predictions.
- **S-RNN**: A meta-architecture for training RNNs that can be applied to streaming data.
- **FastRNN & FastGRNN - FastCells**: **F**ast, **A**ccurate, **S**table and **T**iny (**G**ated) RNN cells.
- **SeeDot**: Floating-point to fixed-point quantization tool.

These algorithms can train models for classical supervised learning problems
with memory requirements that are orders of magnitude lower than other modern
ML algorithms. The trained models can be loaded onto edge devices such as IoT
devices/sensors, and used to make fast and accurate predictions completely
offline.

Applications demonstrating use cases of these algorithms:

- **GesturePod**: Gesture recognition pipeline for microcontrollers.
- **MSC-RNN**: Multi-scale, cascaded RNN for analyzing radar data.

### Organization
- The `tf` directory contains the `edgeml_tf` package which specifies these architectures in TensorFlow,
  and `examples/tf` contains sample training routines for these algorithms.
- The `pytorch` directory contains the `edgeml_pytorch` package which specifies these architectures in PyTorch,
  and `examples/pytorch` contains sample training routines for these algorithms.
- The `cpp` directory has training and inference code for the Bonsai and ProtoNN algorithms in C++.
- The `applications` directory has code/demonstrations of applications of the EdgeML algorithms.
- The `tools/SeeDot` directory has the quantization tool to generate fixed-point inference code.

Please see install/run instructions in the README pages within these directories.
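As a quick orientation to that layout, the two Python packages are imported as follows (a minimal sketch; it assumes you have installed `edgeml_tf` and `edgeml_pytorch` as described in those READMEs):

```
# Minimal sketch: import paths mirror the directory layout described above.
from edgeml_pytorch.graph.bonsai import Bonsai   # from the pytorch/ package
from edgeml_tf.graph.rnn import FastGRNNCell     # from the tf/ package
```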
### Details and project pages
For details, please see our
[project page](https://microsoft.github.io/EdgeML/),
[Microsoft Research page](https://www.microsoft.com/en-us/research/project/resource-efficient-ml-for-the-edge-and-endpoint-iot-devices/),
the ICML '17 publications on [Bonsai](/docs/publications/Bonsai.pdf) and
[ProtoNN](/docs/publications/ProtoNN.pdf),
the NeurIPS '18 publications on [EMI-RNN](/docs/publications/emi-rnn-nips18.pdf) and
[FastGRNN](/docs/publications/FastGRNN.pdf),
the PLDI '19 publication on the [SeeDot compiler](/docs/publications/SeeDot.pdf),
the UIST '19 publication on [GesturePod](/docs/publications/ICane-UIST19.pdf),
the BuildSys '19 publication on [MSC-RNN](/docs/publications/MSCRNN.pdf),
and the NeurIPS '19 publication on [S-RNN](/docs/publications/SRNN.pdf).

Also check out the [ELL](https://github.com/Microsoft/ELL) project, which can
provide optimized binaries for some of the ONNX models trained by this library.

### Contributors
Code for algorithms, applications and tools contributed by:

- [Don Dennis](https://dkdennis.xyz)
- [Yash Gaurkar](https://github.com/mr-yamraj/)
- [Sridhar Gopinath](http://www.sridhargopinath.in/)
- [Chirag Gupta](https://aigen.github.io/)
- [Moksh Jain](https://github.com/MJ10)
- [Ashish Kumar](https://ashishkumar1993.github.io/)
- [Aditya Kusupati](https://adityakusupati.github.io/)
- [Chris Lovett](https://github.com/lovettchris)
- [Shishir Patil](https://shishirpatil.github.io/)
- [Harsha Vardhan Simhadri](http://harsha-simhadri.org)

[Contributors](https://microsoft.github.io/EdgeML/People) to this project. New contributors are welcome.

Please [email us](mailto:edgeml@microsoft.com) your comments, criticism, and questions.

If you use software from this library in your work, please use the BibTex entry below for citation.

```
@software{edgeml03,
   author = {{Dennis, Don Kurian and Gaurkar, Yash and Gopinath, Sridhar and Gupta, Chirag and
              Jain, Moksh and Kumar, Ashish and Kusupati, Aditya and Lovett, Chris and
              Patil, Shishir G and Simhadri, Harsha Vardhan}},
   title = {{EdgeML: Machine Learning for resource-constrained edge devices}},
   url = {https://github.com/Microsoft/EdgeML},
   version = {0.3},
}
```
@@ -0,0 +1,22 @@
MSC-RNN - Multi-Scale, Cascaded RNN
===================================

MSC-RNN is a new RNN architecture proposed in the paper
[One Size Does Not Fit All: Multi-Scale, Cascaded RNNs for Radar Classification](https://arxiv.org/abs/1909.03082),
which won the **Best Paper Runner-Up** award at *BuildSys 2019*.

MSC-RNN is created using EMI-RNN and FastGRNN from the EdgeML repository.
It comprises an EMI-FastGRNN for clutter discrimination at the lower tier and a more complex FastGRNN
classifier for source classification at the upper tier, and it is trained using a novel joint-training routine.
A conceptual sketch of the cascade is shown below.
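To make the two-tier structure concrete, here is a toy sketch of the cascade at inference time (hypothetical stand-ins for the models; the real tiers are the EMI-FastGRNN and FastGRNN above):

```
# Toy sketch of the two-tier cascade: a cheap lower tier rejects clutter so
# the costlier upper tier only runs on windows that likely contain a source.
def msc_rnn_predict(window, clutter_rnn, source_rnn):
    if clutter_rnn(window) == "clutter":   # lower tier: EMI-FastGRNN
        return "clutter"
    return source_rnn(window)              # upper tier: FastGRNN classifier

clutter_rnn = lambda w: "clutter" if max(w) < 0.5 else "source"
source_rnn = lambda w: "human" if sum(w) > 10 else "non-human"
print(msc_rnn_predict([0.1, 0.2, 0.3], clutter_rnn, source_rnn))  # 'clutter'
```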
MSC-RNN holistically improves the accuracy and per-class recalls over ML models suitable for radar inferencing.
Notably, MSC-RNN outperforms cross-domain handcrafted feature engineering with time-domain deep feature learning,
while also being up to ∼3× more efficient than the competitive SVM-based solutions.

# Resources

**Paper** - [pdf](/docs/publications/MSCRNN.pdf) | [arXiv](https://arxiv.org/pdf/1909.03082.pdf) | [ACM DL](https://dl.acm.org/citation.cfm?id=3360860)

**Code** - https://github.com/dhruboroy29/MSCRNN

**Dataset** - https://doi.org/10.5281/zenodo.3451408
@@ -6,7 +6,7 @@ TensorFlow library. This document **does not** seek to be a comprehensive
documentation of the EMI-RNN code base.

For a quick and dirty 'getting started' example, please refer to the
[EMI-RNN](../examples/tf/EMI-RNN) directory.

![MIML Formulation of Bags and Instances](img/MIML_illustration.png)
@@ -49,21 +49,21 @@ sub-instance is initialized to equal the label of the entire bag.
Thus, the EMI-RNN implementation expects the train data to be of shape,

    [Num. of examples, Num. of instances, Num. of timesteps, Num. of features]

Further, the label information is expected to be one-hot encoded and of the
shape,

    [Num. of examples, Num. of instances, Num. classes]

As a concrete end-to-end example, please refer to
[EMI_RNN/fetch_har.py](../examples/tf/EMI-RNN/fetch_har.py).
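For instance, a toy slicing of a time-series dataset into these shapes looks like this (a minimal numpy sketch, not code from the repository):

```
import numpy as np

# Slice a toy time-series dataset into overlapping instances of the shape
# EMI-RNN expects: [examples, instances, timesteps, features].
num_examples, timesteps, num_feats, num_classes = 100, 128, 9, 6
data = np.random.randn(num_examples, timesteps, num_feats)
labels = np.random.randint(0, num_classes, num_examples)

instance_len, stride = 48, 16
starts = range(0, timesteps - instance_len + 1, stride)   # 6 instances
x = np.stack([data[:, s:s + instance_len, :] for s in starts], axis=1)
# x.shape == (100, 6, 48, 9)

# Each sub-instance label starts out as the label of the entire bag,
# one-hot encoded per instance: [examples, instances, classes].
y = np.zeros((num_examples, len(starts), num_classes))
y[np.arange(num_examples)[:, None], np.arange(len(starts))[None, :], labels[:, None]] = 1
# y.shape == (100, 6, 6)
```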
## Training

![An illustration of the parts of the computation graph](img/3PartsGraph.png)

The EMI-RNN algorithm consists of a graph construction phase and a training
phase.

### Graph Construction
@@ -83,7 +83,7 @@ forward computation graphs. All implementations of `EMI_RNN` are expected and
assumed to provide an `EMI_RNN.output` attribute - the Tensor/Operation with
the forward computation outputs. The following implementations of `EMI_RNN` are
provided:
- `EMI_LSTM`
- `EMI_GRU`
- `EMI_FastRNN`
- `EMI_FastGRNN`
@@ -116,11 +116,11 @@ Train_EMI_RNN:
    Y: Train labels
    EMI_Graph: A complete training graph
    updatePolicy: An update policy that will update the instance labels
        after each training round.
    NUM_ROUNDS: Number of rounds of training

    curr_Y = Y
    for round in range(NUM_ROUNDS):
        minimize_loss(EMI_graph, X, curr_Y)
        curr_Y = updatePolicy(EMI_RNN(X))
```
@@ -163,7 +163,7 @@ invalid upon the graph being reset internally.

It is possible to restore a trained model into a session from its checkpoint.
`EMI_Driver` exposes an easy-to-use way of achieving this through the
`loadSavedGraphToNewSession` method.

To use this method, first construct a new computation graph as you would
normally do and set up `EMI_Driver` with this computation graph. Then you can
@@ -181,8 +181,8 @@ matrices is also supported. This is achieved by attaching `tf.assign`
operations to all the model tensors. Please have a look at the `addAssignOps`
method of `DataPipeline`, `EMI_RNN` and `EMI_Trainer` for more information.

Please refer to [EMI_RNN/02_emi_lstm_initialization_and_restoring.ipynb](../examples/tf/EMI-RNN/02_emi_lstm_initialization_and_restoring.ipynb)
for example usages.
## Evaluating the trained model

@@ -210,4 +210,4 @@ Early prediction is accomplished by defining an early prediction policy method.
This method receives the prediction at each step of the learned RNN for a
sub-instance as input and is expected to return a predicted class and the
0-indexed step at which it made this prediction. Please refer to
[EMI-RNN](../examples/tf/EMI-RNN) for concrete examples of the same.
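As a toy illustration of such a policy (a minimal sketch, not the repository's implementation), one can commit to a class as soon as the per-step probability crosses a threshold:

```
import numpy as np

# Early-prediction policy sketch: returns (predicted class, 0-indexed step).
def early_policy(step_probs, min_prob=0.9):
    """step_probs: [num_steps, num_classes] per-step class probabilities."""
    for step, probs in enumerate(step_probs):
        if probs.max() >= min_prob:
            return int(probs.argmax()), step
    # Never confident enough: fall back to the final step's prediction.
    return int(step_probs[-1].argmax()), len(step_probs) - 1

step_probs = np.array([[0.4, 0.6], [0.2, 0.8], [0.05, 0.95]])
print(early_policy(step_probs))  # -> (1, 2)
```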
@@ -1,11 +1,11 @@
# FastRNN and FastGRNN - FastCells

This document elaborates on the details of FastCells
present in [tf/edgeml_tf/graph/rnn.py](/tf/edgeml_tf/graph/rnn.py). The
endpoint use-case scripts with 3-phase training, along with an example notebook,
are present in [examples/tf/FastCells](/examples/tf/FastCells). One can use the endpoint script to test
out the RNN architectures on any dataset while specifying budget constraints as
part of hyper-parameters in terms of sparsity and rank of weight matrices.

# FastRNN
![FastRNN](img/FastRNN.png)
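For reference, the update rule the figure depicts (as given in the FastGRNN paper) is:

$$\tilde{h}_t = \tanh(W x_t + U h_{t-1} + b), \qquad h_t = \alpha\,\tilde{h}_t + \beta\, h_{t-1},$$

where $\alpha$ and $\beta$ are trainable scalars, with $\alpha$ small and $\beta$ close to 1.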
@@ -17,23 +17,23 @@ of weight matrices.

# Plug and Play Cells

`FastRNNCell` and `FastGRNNCell` present in `edgeml.graph.rnn` are very similar to
TensorFlow's inbuilt `RNNCell`, `GRUCell`, `BasicLSTMCell`, and `UGRNNCell`, allowing us to
replace any of the standard RNN cells in our architecture with FastCells.
One can see the plug-and-play nature at the endpoint script for FastCells, where the graph
building is very similar to LSTM/GRU in TensorFlow. A minimal sketch is shown below.

Script: [Endpoint Script](/examples/tf/FastCells/fastcell_example.py)

Example Notebook: [iPython Notebook](/examples/tf/FastCells/fastcell_example.ipynb)

Cells: [FastRNNCell](/tf/edgeml/graph/rnn.py#L206) and [FastGRNNCell](/tf/edgeml/graph/rnn.py#L31).
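A minimal sketch of the plug-and-play usage (TF 1.x style; the import path and the `wRank`/`uRank` parameters are as documented here, the remaining details are assumptions, so consult `fastcell_example.py` for the authoritative version):

```
import tensorflow as tf
from edgeml_tf.graph.rnn import FastGRNNCell

# FastGRNNCell drops in wherever a GRUCell/BasicLSTMCell would go.
# wRank/uRank set the low-rank budget of the W and U matrices.
inputs = tf.placeholder(tf.float32, [None, 99, 32])  # [batch, timesteps, feats]
cell = FastGRNNCell(16, wRank=8, uRank=8)            # 16 hidden units
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
logits = tf.layers.dense(state, 6)                   # classifier on final state
```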
# 3 phase Fast Training

`FastCells`, similar to `Bonsai`, use a 3-phase training routine to induce the right
support and sparsity for the weight matrices. With the low-rank parameterization of weights
followed by the 3-phase training, we obtain FastRNN and FastGRNN models which are compact
and can be further compressed by using byte quantization without significant loss in accuracy.
A sketch of the hard-thresholding step behind the sparsity phase is shown below.
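```
import numpy as np

# Minimal sketch of IHT-style hard thresholding (not the trainer's exact
# code): keep the largest-magnitude entries of W, zero out the rest.
def hard_threshold(W, sparsity):
    k = int(np.ceil(sparsity * W.size))        # number of entries to keep
    cutoff = np.sort(np.abs(W), axis=None)[-k]
    return np.where(np.abs(W) >= cutoff, W, 0.0)

W = np.random.randn(8, 8)
W_sparse = hard_threshold(W, sparsity=0.25)    # ~25% nonzeros retained
print((W_sparse != 0).mean())
```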
# Compression

@@ -42,16 +42,16 @@ and they can be further compressed by using byte quantization without significant
1) Low-rank (L)
2) Sparsity (S)
3) Quantization (Q)

Low-rank is directly induced into the FastCells during initialization, and training happens with
the targeted low-rank versions of the weight matrices. One can use the `wRank` and `uRank` parameters
of FastCells to achieve this.

Sparsity is taken in as a hyper-parameter during the 3-phase training into `fastTrainer.py`, which at the
end spits out a sparse, low-rank model.

Further compression is achieved by byte quantization and can be performed using the `quantizeFastModels.py`
script, which is part of [examples/tf/FastCells](/examples/tf/FastCells). This will give a model size reduction of up to 4x if 8-bit
integers are used. Lastly, to facilitate all-integer arithmetic, including the non-linearities, one could
use `quantTanh` instead of `tanh` and `quantSigm` instead of `sigmoid` as the non-linearities in the RNN
cells, followed by byte quantization. These non-linearities can be set using the appropriate parameters in
the `FastRNNCell` and `FastGRNNCell`.
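The following sketch shows where the "up to 4x" figure comes from and what the quantized non-linearities compute (a minimal illustration under assumed per-matrix scaling, not the repository's exact scheme):

```
import numpy as np

# Byte quantization: map float32 weights to int8 with a per-matrix scale
# (4 bytes -> 1 byte per weight gives the 4x size reduction).
W = np.random.randn(64, 32).astype(np.float32)
scale = np.abs(W).max() / 127.0
W_q = np.clip(np.round(W / scale), -128, 127).astype(np.int8)
W_hat = W_q.astype(np.float32) * scale          # dequantized approximation
print(W.nbytes / W_q.nbytes)                    # 4.0

# Piecewise-linear stand-ins that keep the arithmetic integer-friendly:
def quantTanh(x):
    return np.clip(x, -1.0, 1.0)

def quantSigm(x):
    return np.clip((x + 1.0) / 2.0, 0.0, 1.0)
```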
|
@ -4,10 +4,11 @@ This directory includes, example notebook and general execution script of
|
|||
Bonsai developed as part of EdgeML. Also, we include a sample cleanup and
|
||||
use-case on the USPS10 public dataset.
|
||||
|
||||
`pytorch_edgeml.graph.bonsai` implements the Bonsai prediction graph in pytorch.
|
||||
`edgeml_pytorch.graph.bonsai` implements the Bonsai prediction graph in pytorch.
|
||||
The three-phase training routine for Bonsai is decoupled from the forward graph
|
||||
to facilitate a plug and play behaviour wherein Bonsai can be combined with or
|
||||
used as a final layer classifier for other architectures (RNNs, CNNs).
|
||||
used as a final layer classifier for other architectures (RNNs, CNNs).
|
||||
See `edgeml_pytorch.trainer.bonsaiTrainer` for 3-phase training.
|
||||
|
||||
Note that `bonsai_example.py` assumes that data is in a specific format. It is
|
||||
assumed that train and test data is contained in two files, `train.npy` and
|
|
@@ -4,8 +4,8 @@
import helpermethods
import numpy as np
import sys
from edgeml_pytorch.trainer.bonsaiTrainer import BonsaiTrainer
from edgeml_pytorch.graph.bonsai import Bonsai
import torch
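To illustrate what the decoupled prediction graph computes, here is a conceptual PyTorch sketch of Bonsai's scoring (toy code, not the package API; see `edgeml_pytorch.graph.bonsai` for the real implementation):

```
import torch

# Project the input to a low dimension with Z, then sum gated per-node
# scores along a root-to-leaf path of a shallow tree.
D, d, depth, C, sigma = 784, 16, 3, 10, 1.0
num_nodes = 2 ** (depth + 1) - 1
Z = torch.randn(d, D) * 0.1                # low-dimensional projection
W = torch.randn(num_nodes, C, d) * 0.1     # per-node class predictors
V = torch.randn(num_nodes, C, d) * 0.1     # per-node gating predictors
theta = torch.randn(num_nodes, d) * 0.1    # branching hyperplanes

def bonsai_score(x):
    z = Z @ x
    score, node = torch.zeros(C), 0
    for level in range(depth + 1):
        score = score + (W[node] @ z) * torch.tanh(sigma * (V[node] @ z))
        if level < depth:                  # hard routing at inference time
            node = 2 * node + (1 if theta[node] @ z <= 0 else 2)
    return score                           # unnormalized class scores

print(bonsai_score(torch.randn(D)).shape)  # torch.Size([10])
```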
@@ -0,0 +1,99 @@
## Training a keyword-spotting model

This example demonstrates how to train a FastGRNN-based keyword-spotting model on the Google speech commands dataset,
compile it using the ELL compiler, and deploy it on an [STM BlueCoin](https://www.st.com/en/evaluation-tools/steval-bcnkt01v1.html).
Follow the steps below to featurize data using ELL, train and export an ONNX model using the EdgeML library,
and prepare a binary that provides prediction capability using the ELL library.

### Install ELL
[link](https://github.com/microsoft/ELL)

### Download the Google speech commands dataset
Download the [dataset](https://storage.cloud.google.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz) and extract the data:
```
mkdir data_speech_commands_v2
tar xvzf data_speech_commands_v0.02.tar.gz -C data_speech_commands_v2
```

### Export paths to the dataset and ELL
```
export ELL_ROOT=<path to directory where ELL is installed>
### e.g., export ELL_ROOT=/home/user/ELL
export DATASET_PATH=<path to directory where the speech commands dataset is extracted>
### e.g., export DATASET_PATH=/mnt/../../data_speech_data_v2
```

### Make the training list
Use `-max n` to override the default limit on the maximum number of samples from each category, including `background`. For a low false-positive rate, train with a large number of negative `background` examples, say 50000 or 250000.
```
python3 $ELL_ROOT/tools/utilities/pythonlibs/audio/training/make_training_list.py -max 50000 --wav_files $DATASET_PATH
```

### Create an ELL featurizer
```
python $ELL_ROOT/tools/utilities/pythonlibs/audio/training/make_featurizer.py -ws 400 --nfft 512 --iir --log --log_delta 2.220446049250313e-16 --power_spec -fs 32
```

### Compile the ELL featurizer
```
python $ELL_ROOT/tools/wrap/wrap.py --model_file featurizer.ell --outdir compiled_featurizer --module_name mfcc
cd compiled_featurizer && mkdir build && cd build && cmake .. && make && cd ../..
```

### Pre-process the dataset
```
python $ELL_ROOT/tools/utilities/pythonlibs/audio/training/make_dataset.py --list_file $DATASET_PATH/training_list.txt --featurizer compiled_featurizer/mfcc --window_size 98 --shift 98 --multicore
python $ELL_ROOT/tools/utilities/pythonlibs/audio/training/make_dataset.py --list_file $DATASET_PATH/validation_list.txt --featurizer compiled_featurizer/mfcc --window_size 98 --shift 98 --multicore
python $ELL_ROOT/tools/utilities/pythonlibs/audio/training/make_dataset.py --list_file $DATASET_PATH/testing_list.txt --featurizer compiled_featurizer/mfcc --window_size 98 --shift 98 --multicore
```

If you have background noise clips (not containing keywords) that you want to fuse with your dataset,
place them in a folder `$DATASET_PATH/backgroundNoise` and follow these instructions instead of the ones above.
```
python $ELL_ROOT/tools/utilities/pythonlibs/audio/training/make_dataset.py --list_file $DATASET_PATH/training_list.txt --featurizer compiled_featurizer/mfcc --window_size 98 --shift 98 --multicore --noise_path $DATASET_PATH/backgroundNoise --max_noise_ratio 0.1 --noise_selection 1
python $ELL_ROOT/tools/utilities/pythonlibs/audio/training/make_dataset.py --list_file $DATASET_PATH/validation_list.txt --featurizer compiled_featurizer/mfcc --window_size 98 --shift 98 --multicore --noise_path $DATASET_PATH/backgroundNoise --max_noise_ratio 0.1 --noise_selection 1
python $ELL_ROOT/tools/utilities/pythonlibs/audio/training/make_dataset.py --list_file $DATASET_PATH/testing_list.txt --featurizer compiled_featurizer/mfcc --window_size 98 --shift 98 --multicore --noise_path $DATASET_PATH/backgroundNoise --max_noise_ratio 0.1 --noise_selection 1
```
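Each featurized split is written as a `.npz` archive. As a quick sanity check, you can inspect one using the keys that `train_classifier.py` (below) reads; the exact path is an assumption based on the commands above:

```
import numpy as np

# "parameters", "features" and "labels" are the keys read by AudioDataset.
ds = np.load("data_speech_commands_v2/training_list.npz")  # assumed location
sample_rate, audio_size, input_size, window_size, shift = (int(p) for p in ds["parameters"][:5])
print(ds["features"].shape, len(ds["labels"]), input_size, window_size)
```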
### Run model training
```
python examples/pytorch/FastCells/train_classifier.py \
    --use_gpu --normalize --rolling --max_rolling_length 235 \
    -a $DATASET_PATH -c $DATASET_PATH/categories.txt --outdir $MODEL_DIR \
    --architecture FastGRNNCUDA --num_layers 2 \
    --epochs 250 --learning_rate 0.005 -bs 128 -hu 128 \
    --lr_min 0.0005 --lr_scheduler CosineAnnealingLR --lr_peaks 0
```
Drop the `--rolling` and `--max_rolling_length` options if you are going to run inference on 1-second clips
and do not plan to stream data through the model without resetting.

### Convert the .onnx model to .ell IR
```
pip install onnx  # if you haven't already
python $ELL_ROOT/tools/importers/onnx/onnx_import.py output_model/model.onnx
```
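Before importing, it can be worth sanity-checking the exported ONNX model with a dummy input (a minimal sketch; assumes the `onnxruntime` package and the `output_model/model.onnx` path from the previous step):

```
import numpy as np
import onnxruntime as ort

# Run one dummy frame through the exported model and print the output shape.
sess = ort.InferenceSession("output_model/model.onnx")
inp = sess.get_inputs()[0]
shape = [d if isinstance(d, int) else 1 for d in inp.shape]  # fill dynamic dims
scores = sess.run(None, {inp.name: np.random.randn(*shape).astype(np.float32)})[0]
print(inp.name, shape, "->", scores.shape)
```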
### Compile model and featurizer header and binary files for ARM Cortex-M4 class devices

These commands use the ELL compiler to generate several files, of which four are required: `featurizer.h`, `featurizer.S`, `model.h` and `model.S`.

#### For devices with a hard FPU, e.g., STM BlueCoin
```
$ELL_ROOT/build/bin/compile -imap model.ell -cfn Predict -cmn completemodel --bitcode -od . --fuseLinearOps True --header --blas false --optimize true --target custom --numBits 32 --cpu cortex-m4 --triple armv6m-gnueabi --features +vfp4,+d16
/usr/lib/llvm-8/bin/opt model.bc -o model.opt.bc -O3
/usr/lib/llvm-8/bin/llc model.opt.bc -o model.S -O3 -filetype=asm -mtriple=armv6m-gnueabi -mcpu=cortex-m4 -relocation-model=pic -float-abi=hard -mattr=+vfp4,+d16
$ELL_ROOT/build/bin/compile -imap featurizer.ell -cfn Filter -cmn mfcc --bitcode -od . --fuseLinearOps True --header --blas false --optimize true --target custom --numBits 32 --cpu cortex-m4 --triple armv6m-gnueabi --features +vfp4,+d16
/usr/lib/llvm-8/bin/opt featurizer.bc -o featurizer.opt.bc -O3
/usr/lib/llvm-8/bin/llc featurizer.opt.bc -o featurizer.S -O3 -filetype=asm -mtriple=armv6m-gnueabi -mcpu=cortex-m4 -relocation-model=pic -float-abi=hard -mattr=+vfp4,+d16
```

#### For M4-class devices without a hard FPU, e.g., MXChip
```
$ELL_ROOT/build/bin/compile -imap model.ell -cfn Predict -cmn completemodel --bitcode -od . --fuseLinearOps True --header --blas false --optimize true --target custom --numBits 32 --cpu cortex-m4 --triple armv6m-gnueabi --features +vfp4,+d16,+soft-float
/usr/lib/llvm-8/bin/opt model.bc -o model.opt.bc -O3
/usr/lib/llvm-8/bin/llc model.opt.bc -o model.S -O3 -filetype=asm -mtriple=armv6m-gnueabi -mcpu=cortex-m4 -relocation-model=pic -float-abi=soft -mattr=+vfp4,+d16
$ELL_ROOT/build/bin/compile -imap featurizer.ell -cfn Filter -cmn mfcc --bitcode -od . --fuseLinearOps True --header --blas false --optimize true --target custom --numBits 32 --cpu cortex-m4 --triple armv6m-gnueabi --features +vfp4,+d16,+soft-float
/usr/lib/llvm-8/bin/opt featurizer.bc -o featurizer.opt.bc -O3
/usr/lib/llvm-8/bin/llc featurizer.opt.bc -o featurizer.S -O3 -filetype=asm -mtriple=armv6m-gnueabi -mcpu=cortex-m4 -relocation-model=pic -float-abi=soft -mattr=+vfp4,+d16
```
@@ -0,0 +1,747 @@
#!/usr/bin/env python3
###################################################################################################
#
# Project:  Embedded Learning Library (ELL)
# File:     train_classifier.py
# Authors:  Chris Lovett
#
# Requires: Python 3.x
#
###################################################################################################

import argparse
import json
import math
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.onnx
import random

from torch.autograd import Variable, Function
from torch.utils.data import Dataset, DataLoader

from training_config import TrainingConfig
from edgeml_pytorch.trainer.fastmodel import *
class KeywordSpotter(nn.Module):
    """ This baseclass provides the PyTorch Module pattern for defining and training keyword spotters """

    def __init__(self):
        """
        Initialize the KeywordSpotter. Subclasses provide the following parameters:
        input_dim - the size of the input audio frame in # samples
        num_keywords - the number of predictions to come out of the model.
        """
        super(KeywordSpotter, self).__init__()

        self.training = False
        self.tracking = False

        self.init_hidden()

    def name(self):
        return "KeywordSpotter"

    def init_hidden(self):
        """ Clear any hidden state """
        pass

    def forward(self, input):
        """ Perform the forward processing of the given input and return the prediction """
        raise Exception("need to implement the forward method")

    def export(self, name, device):
        """ Export the model to the ONNX file format """
        self.init_hidden()
        self.tracking = True
        dummy_input = Variable(torch.randn(1, 1, self.input_dim))
        if device:
            dummy_input = dummy_input.to(device)
        torch.onnx.export(self, dummy_input, name, verbose=True)
        self.tracking = False

    def batch_accuracy(self, scores, labels):
        """ Compute the training accuracy of the results of a single mini-batch """
        batch_size = scores.shape[0]
        passed = 0
        results = []
        for i in range(batch_size):
            expected = labels[i]
            actual = scores[i].argmax()
            results += [int(actual)]
            if expected == actual:
                passed += 1
        return (float(passed) * 100.0 / float(batch_size), passed, results)
    def configure_optimizer(self, options):
        initial_rate = options.learning_rate
        oo = options.optimizer_options

        if options.optimizer == "Adadelta":
            optimizer = optim.Adadelta(self.parameters(), lr=initial_rate, weight_decay=oo.weight_decay,
                                       rho=oo.rho, eps=oo.eps)
        elif options.optimizer == "Adagrad":
            optimizer = optim.Adagrad(self.parameters(), lr=initial_rate, weight_decay=oo.weight_decay,
                                      lr_decay=oo.lr_decay)
        elif options.optimizer == "Adam":
            optimizer = optim.Adam(self.parameters(), lr=initial_rate, weight_decay=oo.weight_decay,
                                   betas=oo.betas, eps=oo.eps)
        elif options.optimizer == "Adamax":
            optimizer = optim.Adamax(self.parameters(), lr=initial_rate, weight_decay=oo.weight_decay,
                                     betas=oo.betas, eps=oo.eps)
        elif options.optimizer == "ASGD":
            optimizer = optim.ASGD(self.parameters(), lr=initial_rate, weight_decay=oo.weight_decay,
                                   lambd=oo.lambd, alpha=oo.alpha, t0=oo.t0)
        elif options.optimizer == "RMSprop":
            optimizer = optim.RMSprop(self.parameters(), lr=initial_rate, weight_decay=oo.weight_decay,
                                      eps=oo.eps, alpha=oo.alpha, momentum=oo.momentum, centered=oo.centered)
        elif options.optimizer == "Rprop":
            optimizer = optim.Rprop(self.parameters(), lr=initial_rate, etas=oo.etas,
                                    step_sizes=oo.step_sizes)
        elif options.optimizer == "SGD":
            optimizer = optim.SGD(self.parameters(), lr=initial_rate, weight_decay=oo.weight_decay,
                                  momentum=oo.momentum, dampening=oo.dampening, nesterov=oo.nesterov)
        else:
            raise ValueError("Unknown optimizer: {}".format(options.optimizer))
        return optimizer

    def configure_lr(self, options, optimizer, ticks, total_iterations):
        num_epochs = options.max_epochs
        learning_rate = options.learning_rate
        lr_scheduler = options.lr_scheduler
        lr_min = options.lr_min
        lr_peaks = options.lr_peaks
        gamma = options.lr_gamma
        if not lr_min:
            lr_min = learning_rate
        scheduler = None
        if lr_scheduler == "TriangleLR":
            steps = lr_peaks * 2 + 1
            stepsize = num_epochs / steps
            scheduler = TriangularLR(optimizer, stepsize * ticks, lr_min, learning_rate, gamma)
        elif lr_scheduler == "CosineAnnealingLR":
            # divide by an odd number to finish on the minimum learning rate
            cycles = lr_peaks * 2 + 1
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_iterations / cycles,
                                                             eta_min=lr_min)
        elif lr_scheduler == "ExponentialLR":
            scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma)
        elif lr_scheduler == "StepLR":
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=options.lr_step_size, gamma=gamma)
        elif lr_scheduler == "ExponentialResettingLR":
            reset = (num_epochs * ticks) / 3  # reset at the 1/3 mark.
            scheduler = ExponentialResettingLR(optimizer, gamma, reset)
        return scheduler
    def fit(self, training_data, validation_data, options, sparsify=False, device=None, detail=False, run=None):
        """
        Perform the training. This is not called "train" because
        the base class already defines that method with a different meaning.
        The base class "train" method puts the Module into "training mode".
        """
        print("Training {} using {} rows of featurized training input...".format(self.name(), training_data.num_rows))

        if training_data.mean is not None:
            mean = torch.from_numpy(np.array([[training_data.mean]])).to(device)
            std = torch.from_numpy(np.array([[training_data.std]])).to(device)
        else:
            mean = None
            std = None

        self.normalize(mean, std)

        self.training = True
        start = time.time()
        loss_function = nn.NLLLoss()
        optimizer = self.configure_optimizer(options)
        print(optimizer)

        num_epochs = options.max_epochs
        batch_size = options.batch_size
        trim_level = options.trim_level

        ticks = training_data.num_rows / batch_size  # iterations per epoch

        # Calculation of total iterations in non-rolling vs rolling training
        # ticks = num_rows/batch_size (total number of iterations per epoch)
        # Non-Rolling Training:
        #   Total Iterations = num_epochs * ticks
        # Rolling Training:
        #   irl = Initial_rolling_length (We are using 2)
        #   If num_epochs <= max_rolling_length:
        #     Total Iterations = sum(range(irl, irl + num_epochs))
        #   If num_epochs > max_rolling_length:
        #     Total Iterations = sum(range(irl, irl + max_rolling_length)) + (num_epochs - max_rolling_length)*ticks
        if options.rolling:
            rolling_length = 2
            max_rolling_length = int(ticks)
            if max_rolling_length > options.max_rolling_length + rolling_length:
                max_rolling_length = options.max_rolling_length + rolling_length
            bag_count = 100
            hidden_bag_size = batch_size * bag_count
            if num_epochs + rolling_length < max_rolling_length:
                max_rolling_length = num_epochs + rolling_length
            total_iterations = sum(range(rolling_length, max_rolling_length))
            if num_epochs + rolling_length > max_rolling_length:
                epochs_remaining = num_epochs + rolling_length - max_rolling_length
                total_iterations += epochs_remaining * training_data.num_rows / batch_size
            ticks = total_iterations / num_epochs
        else:
            total_iterations = ticks * num_epochs

        scheduler = self.configure_lr(options, optimizer, ticks, total_iterations)

        # optimizer = optim.Adam(model.parameters(), lr=0.0001)
        log = []
        iteration = 0  # running count of mini-batches across all epochs

        for epoch in range(num_epochs):
            self.train()
            if options.rolling:
                rolling_length += 1
                if rolling_length <= max_rolling_length:
                    self.init_hidden_bag(hidden_bag_size, device)
            for i_batch, (audio, labels) in enumerate(training_data.get_data_loader(batch_size)):
                if not self.batch_first:
                    audio = audio.transpose(1, 0)  # GRU wants seq,batch,feature

                if device:
                    self.move_to(device)
                    audio = audio.to(device)
                    labels = labels.to(device)

                # Also, we need to clear out the hidden state,
                # detaching it from its history on the last instance.
                if options.rolling:
                    if rolling_length <= max_rolling_length:
                        if (i_batch + 1) % rolling_length == 0:
                            self.init_hidden()
                            break

                    self.rolling_step()
                else:
                    self.init_hidden()

                self.to(device)  # sparsify routines might move param matrices to cpu

                # Before the backward pass, use the optimizer object to zero all of the
                # gradients for the variables it will update (which are the learnable
                # weights of the model). This is because by default, gradients are
                # accumulated in buffers (i.e., not overwritten) whenever .backward()
                # is called. Check out the docs of torch.autograd.backward for more details.
                optimizer.zero_grad()

                # Run our forward pass.
                keyword_scores = self(audio)

                # Compute the loss, gradients
                loss = loss_function(keyword_scores, labels)

                # Backward pass: compute gradient of the loss with respect to all the learnable
                # parameters of the model. Internally, the parameters of each Module are stored
                # in Tensors with requires_grad=True, so this call will compute gradients for
                # all learnable parameters in the model.
                loss.backward()
                # move to next learning rate
                if scheduler:
                    scheduler.step()

                # Calling the step function on an Optimizer makes an update to its parameters
                # applying the gradients we computed during back propagation
                optimizer.step()

                if sparsify:
                    if epoch >= num_epochs / 3:
                        if epoch < (2 * num_epochs) / 3:
                            if i_batch % trim_level == 0:
                                self.sparsify()
                            else:
                                self.sparsifyWithSupport()
                        else:
                            self.sparsifyWithSupport()
                    self.to(device)  # sparsify routines might move param matrices to cpu

                iteration += 1
                if detail:
                    learning_rate = optimizer.param_groups[0]['lr']
                    log += [{'iteration': iteration, 'loss': loss.item(), 'learning_rate': learning_rate}]
            # Find the best prediction in each sequence and return its accuracy
            passed, total, rate = self.evaluate(validation_data, batch_size, device)
            learning_rate = optimizer.param_groups[0]['lr']
            current_loss = float(loss.item())
            print("Epoch {}, Loss {:.3f}, Validation Accuracy {:.3f}, Learning Rate {}".format(
                  epoch, current_loss, rate * 100, learning_rate))
            log += [{'epoch': epoch, 'loss': current_loss, 'accuracy': rate, 'learning_rate': learning_rate}]
            if run is not None:
                run.log('progress', epoch / num_epochs)
                run.log('epoch', epoch)
                run.log('accuracy', rate)
                run.log('loss', current_loss)
                run.log('learning_rate', learning_rate)

        end = time.time()
        self.training = False
        print("Trained in {:.2f} seconds".format(end - start))
        print("Model size {}".format(self.get_model_size()))
        return log
    def evaluate(self, test_data, batch_size, device=None, outfile=None):
        """
        Evaluate the given test data and print the pass rate
        """
        self.eval()
        passed = 0
        total = 0

        self.zero_grad()
        results = []
        with torch.no_grad():
            for i_batch, (audio, labels) in enumerate(test_data.get_data_loader(batch_size)):
                batch_size = audio.shape[0]
                audio = audio.transpose(1, 0)  # GRU wants seq,batch,feature
                if device:
                    audio = audio.to(device)
                    labels = labels.to(device)
                total += batch_size
                self.init_hidden()
                keyword_scores = self(audio)
                last_accuracy, ok, actual = self.batch_accuracy(keyword_scores, labels)
                results += actual
                passed += ok

        if outfile:
            print("Saving evaluation results in '{}'".format(outfile))
            with open(outfile, "w") as f:
                json.dump(results, f)

        return (passed, total, passed / total)
class AudioDataset(Dataset):
    """
    Featurized Audio in a PyTorch Dataset so we can get a DataLoader that is needed for
    mini-batch training.
    """

    def __init__(self, filename, config, keywords, training=False):
        """ Initialize the AudioDataset from the given *.npz file """
        self.dataset = np.load(filename)

        # get parameters saved by make_dataset.py
        parameters = self.dataset["parameters"]
        self.sample_rate = int(parameters[0])
        self.audio_size = int(parameters[1])
        self.input_size = int(parameters[2])
        self.window_size = int(parameters[3])
        self.shift = int(parameters[4])
        self.features = self.dataset["features"].astype(np.float32)
        self.num_rows = len(self.features)
        self.features = self.features.reshape((self.num_rows, self.window_size, self.input_size))

        if config.normalize:
            mean = self.features.mean(axis=0)
            std = self.features.std(axis=0)
            self.mean = mean.mean(axis=0).astype(np.float32)
            std = std.mean(axis=0)
            # self.std is a divisor, so make sure it contains no zeros
            self.std = np.array(np.where(std == 0, 1, std)).astype(np.float32)
        else:
            self.mean = None
            self.std = None

        self.label_names = self.dataset["labels"]
        self.keywords = keywords
        self.num_keywords = len(self.keywords)
        self.labels = self.to_long_vector()

        self.keywords_idx = None
        self.non_keywords_idx = None
        if training and config.sample_non_kw is not None:
            self.keywords_idx, self.non_keywords_idx = self.get_keyword_idx(config.sample_non_kw)
            self.sample_non_kw_probability = config.sample_non_kw_probability

        msg = "Loaded dataset {} and found sample rate {}, audio_size {}, input_size {}, window_size {} and shift {}"
        print(msg.format(os.path.basename(filename), self.sample_rate, self.audio_size, self.input_size,
                         self.window_size, self.shift))

    def get_data_loader(self, batch_size):
        """ Get a DataLoader that can enumerate shuffled batches of data in this dataset """
        return DataLoader(self, batch_size=batch_size, shuffle=True, drop_last=True)

    def to_long_vector(self):
        """ Convert the expected labels to a list of integer indexes into the array of keywords """
        indexer = [(0 if x == "<null>" else self.keywords.index(x)) for x in self.label_names]
        return np.array(indexer, dtype=np.longlong)

    def get_keyword_idx(self, non_kw_label):
        """ Find the keyword rows and the non-keyword rows and store their indexes """
        indexer = [ids for ids, label in enumerate(self.label_names) if label != non_kw_label]
        non_indexer = [ids for ids, label in enumerate(self.label_names) if label == non_kw_label]
        return (np.array(indexer, dtype=np.longlong), np.array(non_indexer, dtype=np.longlong))

    def __len__(self):
        """ Return the number of rows in this Dataset """
        if self.non_keywords_idx is None:
            return self.num_rows
        else:
            return int(len(self.keywords_idx) / (1 - self.sample_non_kw_probability))

    def __getitem__(self, idx):
        """ Return a single labelled sample here as a tuple """
        if self.non_keywords_idx is None:
            updated_idx = idx
        else:
            if idx < len(self.keywords_idx):
                updated_idx = self.keywords_idx[idx]
            else:
                updated_idx = np.random.choice(self.non_keywords_idx)
        audio = self.features[updated_idx]  # batch index is second dimension
        label = self.labels[updated_idx]
        sample = (audio, label)
        return sample
def create_model(model_config, input_size, num_keywords):
    ModelClass = get_model_class(KeywordSpotter)
    hidden_units_list = [model_config.hidden_units1, model_config.hidden_units2, model_config.hidden_units3]
    wRank_list = [model_config.wRank1, model_config.wRank2, model_config.wRank3]
    uRank_list = [model_config.uRank1, model_config.uRank2, model_config.uRank3]
    wSparsity_list = [model_config.wSparsity, model_config.wSparsity, model_config.wSparsity]
    uSparsity_list = [model_config.uSparsity, model_config.uSparsity, model_config.uSparsity]
    print(model_config.gate_nonlinearity, model_config.update_nonlinearity)
    return ModelClass(model_config.architecture, input_size, model_config.num_layers,
                      hidden_units_list, wRank_list, uRank_list, wSparsity_list,
                      uSparsity_list, model_config.gate_nonlinearity,
                      model_config.update_nonlinearity, num_keywords)


def save_json(obj, filename):
    with open(filename, "w") as f:
        json.dump(obj, f, indent=2)
def train(config, evaluate_only=False, outdir=".", detail=False, azureml=False):

    filename = config.model.filename
    categories_file = config.dataset.categories
    wav_directory = config.dataset.path
    batch_size = config.training.batch_size
    hidden_units = config.model.hidden_units
    architecture = config.model.architecture
    num_layers = config.model.num_layers
    use_gpu = config.training.use_gpu

    run = None

    if azureml:
        from azureml.core.run import Run
        run = Run.get_context()
        if run is None:
            print("### Run.get_context() returned None")
        else:
            print("### Running in Azure Context")

    valid_layers = [1, 2, 3]
    if num_layers not in valid_layers:
        raise Exception("--num_layers can only be one of these values {}".format(valid_layers))

    if not os.path.isdir(outdir):
        os.makedirs(outdir)

    if not filename:
        filename = "{}{}KeywordSpotter.pt".format(architecture, hidden_units)
        config.model.filename = filename

    # load the featurized data
    if not os.path.isdir(wav_directory):
        print("### Error: please specify valid --dataset folder location: {}".format(wav_directory))
        sys.exit(1)

    if not categories_file:
        categories_file = os.path.join(wav_directory, "categories.txt")

    with open(categories_file, "r") as f:
        keywords = [x.strip() for x in f.readlines()]

    training_file = os.path.join(wav_directory, "training_list.npz")
    testing_file = os.path.join(wav_directory, "testing_list.npz")
    validation_file = os.path.join(wav_directory, "validation_list.npz")

    if not os.path.isfile(training_file):
        print("Missing file {}".format(training_file))
        print("Please run make_datasets.py")
        sys.exit(1)
    if not os.path.isfile(validation_file):
        print("Missing file {}".format(validation_file))
        print("Please run make_datasets.py")
        sys.exit(1)
    if not os.path.isfile(testing_file):
        print("Missing file {}".format(testing_file))
        print("Please run make_datasets.py")
        sys.exit(1)

    model = None

    device = torch.device("cpu")
    if use_gpu:
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            print("### CUDA not available!!")

    print("Loading {}...".format(testing_file))
    test_data = AudioDataset(testing_file, config.dataset, keywords)

    log = None
    if not evaluate_only:
        print("Loading {}...".format(training_file))
        training_data = AudioDataset(training_file, config.dataset, keywords, training=True)

        print("Loading {}...".format(validation_file))
        validation_data = AudioDataset(validation_file, config.dataset, keywords)

        if training_data.mean is not None:
            fname = os.path.join(outdir, "mean.npy")
            print("Saving {}".format(fname))
            np.save(fname, training_data.mean)
            fname = os.path.join(outdir, "std.npy")
            print("Saving {}".format(fname))
            np.save(fname, training_data.std)

            # use the training_data mean and std variation
            test_data.mean = training_data.mean
            test_data.std = training_data.std
            validation_data.mean = training_data.mean
            validation_data.std = training_data.std

        print("Training model {}".format(filename))
        model = create_model(config.model, training_data.input_size, training_data.num_keywords)
        if device.type == 'cuda':
            model.cuda()  # move the processing to GPU

        start = time.time()
        log = model.fit(training_data, validation_data, config.training,
                        config.model.sparsify, device, detail, run)
        end = time.time()

        passed, total, rate = model.evaluate(training_data, batch_size, device)
        print("Training accuracy = {:.3f} %".format(rate * 100))

        torch.save(model.state_dict(), os.path.join(outdir, filename))

    print("Evaluating {} keyword spotter using {} rows of featurized test audio...".format(
          architecture, test_data.num_rows))
    if model is None:
        msg = "Loading trained model with input size {}, hidden units {} and num keywords {}"
        print(msg.format(test_data.input_size, hidden_units, test_data.num_keywords))
        model = create_model(config.model, test_data.input_size, test_data.num_keywords)
        model.load_state_dict(torch.load(filename))
        if model and device.type == 'cuda':
            model.cuda()  # move the processing to GPU

    results_file = os.path.join(outdir, "results.txt")
    passed, total, rate = model.evaluate(test_data, batch_size, device, results_file)
    print("Testing accuracy = {:.3f} %".format(rate * 100))

    if not evaluate_only:
        name = os.path.splitext(filename)[0] + ".onnx"
        print("saving onnx file: {}".format(name))
        model.export(os.path.join(outdir, name), device)

        config.dataset.sample_rate = test_data.sample_rate
        config.dataset.input_size = test_data.audio_size
        config.dataset.num_filters = test_data.input_size
        config.dataset.window_size = test_data.window_size
        config.dataset.shift = test_data.shift

        logdata = {
            "accuracy_val": rate,
            "training_time": end - start,
            "log": log
        }
        d = TrainingConfig.to_dict(config)
        logdata.update(d)

        logname = os.path.join(outdir, "train_results.json")
        save_json(logdata, logname)

    return rate, log
def str2bool(v):
    if v is None:
        return False
    lower = v.lower()
    return lower in ["t", "1", "true", "yes"]
if __name__ == '__main__':
    parser = argparse.ArgumentParser("train an RNN based neural network for keyword spotting")

    # all the training parameters
    parser.add_argument("--epochs", help="Number of epochs to train", type=int)
    parser.add_argument("--trim_level", help="Number of batches before sparse support is updated in IHT", type=int)
    parser.add_argument("--lr_scheduler", help="Type of learning rate scheduler (None, TriangleLR, CosineAnnealingLR,"
                        " ExponentialLR, ExponentialResettingLR)")
    parser.add_argument("--learning_rate", help="Default learning rate, and maximum for schedulers", type=float)
    parser.add_argument("--lr_min", help="Minimum learning rate for the schedulers", type=float)
    parser.add_argument("--lr_peaks", help="Number of peaks for triangle and cosine schedules", type=float)
    parser.add_argument("--batch_size", "-bs", help="Batch size of training", type=int)
    parser.add_argument("--architecture", help="Specify model architecture (FastGRNN)")
    parser.add_argument("--num_layers", type=int, help="Number of RNN layers (1, 2 or 3)")
    parser.add_argument("--hidden_units", "-hu", type=int, help="Number of hidden units in the FastGRNN layers")
    parser.add_argument("--hidden_units1", "-hu1", type=int, help="Number of hidden units in the FastGRNN 1st layer")
    parser.add_argument("--hidden_units2", "-hu2", type=int, help="Number of hidden units in the FastGRNN 2nd layer")
    parser.add_argument("--hidden_units3", "-hu3", type=int, help="Number of hidden units in the FastGRNN 3rd layer")
    parser.add_argument("--use_gpu", help="Whether to use the GPU for training", action="store_true")
    parser.add_argument("--normalize", help="Whether to normalize audio dataset", action="store_true")
    parser.add_argument("--rolling", help="Whether to train model in rolling fashion or not", action="store_true")
    parser.add_argument("--max_rolling_length", help="Max number of epochs you want to roll the rolling training;"
                        " default is 100", type=int)
    parser.add_argument("--sample_non_kw", "-sl", type=str, help="Sample data for this label with probability sample_non_kw_probability")
    parser.add_argument("--sample_non_kw_probability", "-spr", type=float, help="Sample from sample_non_kw with this probability")

    # arguments for fastgrnn
    parser.add_argument("--wRank", "-wr", help="Rank of W in 1st layer of FastGRNN default is None", type=int)
    parser.add_argument("--uRank", "-ur", help="Rank of U in 1st layer of FastGRNN default is None", type=int)
    parser.add_argument("--wRank1", "-wr1", help="Rank of W in 1st layer of FastGRNN default is None", type=int)
    parser.add_argument("--uRank1", "-ur1", help="Rank of U in 1st layer of FastGRNN default is None", type=int)
    parser.add_argument("--wRank2", "-wr2", help="Rank of W in 2nd layer of FastGRNN default is None", type=int)
    parser.add_argument("--uRank2", "-ur2", help="Rank of U in 2nd layer of FastGRNN default is None", type=int)
    parser.add_argument("--wRank3", "-wr3", help="Rank of W in 3rd layer of FastGRNN default is None", type=int)
    parser.add_argument("--uRank3", "-ur3", help="Rank of U in 3rd layer of FastGRNN default is None", type=int)
    parser.add_argument("--wSparsity", "-wsp", help="Sparsity of W matrices", type=float)
    parser.add_argument("--uSparsity", "-usp", help="Sparsity of U matrices", type=float)
    parser.add_argument("--gate_nonlinearity", "-gnl", help="Gate non-linearity in FastGRNN; default is sigmoid;"
                        " use one of [sigmoid, quantSigmoid, tanh, quantTanh]")
    parser.add_argument("--update_nonlinearity", "-unl", help="Update non-linearity in FastGRNN; default is tanh;"
                        " use one of [sigmoid, quantSigmoid, tanh, quantTanh]")

    # or you can just specify an options file.
    parser.add_argument("--config", help="Use json file containing all these options (as per 'training_config.py')")

    # and some additional stuff ...
    parser.add_argument("--azureml", help="Tells script we are running in Azure ML context")
    parser.add_argument("--eval", "-e", help="No training, just evaluate existing model", action='store_true')
    parser.add_argument("--filename", "-o", help="Name of model file to generate")
    parser.add_argument("--categories", "-c", help="Name of file containing keywords")
    parser.add_argument("--dataset", "-a", help="Path to the audio folder containing 'training.npz' file")
    parser.add_argument("--outdir", help="Folder in which to store output file and log files")
    parser.add_argument("--detail", "-d", help="Save loss info for every iteration not just every epoch",
                        action="store_true")
    args = parser.parse_args()

    config = TrainingConfig()
    if args.config:
        config.load(args.config)

    azureml = str2bool(args.azureml)

    # then any user defined options override these defaults
    if args.epochs:
        config.training.max_epochs = args.epochs
    if args.trim_level:
        config.training.trim_level = args.trim_level
    else:
        config.training.trim_level = 15
    if args.learning_rate:
        config.training.learning_rate = args.learning_rate
    if args.lr_min:
        config.training.lr_min = args.lr_min
    if args.lr_peaks:
        config.training.lr_peaks = args.lr_peaks
    if args.lr_scheduler:
        config.training.lr_scheduler = args.lr_scheduler
    if args.batch_size:
        config.training.batch_size = args.batch_size
    if args.rolling:
        config.training.rolling = args.rolling
    if args.max_rolling_length:
        config.training.max_rolling_length = args.max_rolling_length
    if args.architecture:
        config.model.architecture = args.architecture
    if args.num_layers:
        config.model.num_layers = args.num_layers
    if args.hidden_units:
        config.model.hidden_units = args.hidden_units
    if args.hidden_units1:
        config.model.hidden_units1 = args.hidden_units1
    if args.hidden_units2:
        config.model.hidden_units2 = args.hidden_units2
    if args.hidden_units3:
        config.model.hidden_units3 = args.hidden_units3
    if config.model.num_layers >= 1:
        if config.model.hidden_units1 is None:
            config.model.hidden_units1 = config.model.hidden_units
    if config.model.num_layers >= 2:
        if config.model.hidden_units2 is None:
            config.model.hidden_units2 = config.model.hidden_units1
    if config.model.num_layers == 3:
        if config.model.hidden_units3 is None:
            config.model.hidden_units3 = config.model.hidden_units2
    if args.filename:
        config.model.filename = args.filename
    if args.use_gpu:
        config.training.use_gpu = args.use_gpu
    if args.normalize:
        config.dataset.normalize = args.normalize
    if args.categories:
        config.dataset.categories = args.categories
    if args.dataset:
        config.dataset.path = args.dataset
    if args.sample_non_kw:
        config.dataset.sample_non_kw = args.sample_non_kw
        if args.sample_non_kw_probability is None:
            config.dataset.sample_non_kw_probability = 0.5
        else:
            config.dataset.sample_non_kw_probability = args.sample_non_kw_probability
    else:
        config.dataset.sample_non_kw = None

    if args.wRank:
        config.model.wRank = args.wRank
    if args.uRank:
        config.model.uRank = args.uRank
    if args.wRank1:
        config.model.wRank1 = args.wRank1
    if args.uRank1:
        config.model.uRank1 = args.uRank1
    if config.model.wRank1 is None:
        if config.model.wRank is not None:
            config.model.wRank1 = config.model.wRank
    if config.model.uRank1 is None:
        if config.model.uRank is not None:
            config.model.uRank1 = config.model.uRank
    if args.wRank2:
        config.model.wRank2 = args.wRank2
    if args.uRank2:
        config.model.uRank2 = args.uRank2
    if args.wRank3:
        config.model.wRank3 = args.wRank3
    if args.uRank3:
        config.model.uRank3 = args.uRank3
    if args.wSparsity:
        config.model.wSparsity = args.wSparsity
    else:
        config.model.wSparsity = 1.0
    if args.uSparsity:
        config.model.uSparsity = args.uSparsity
    else:
        config.model.uSparsity = 1.0
    if config.model.uSparsity < 1.0 or config.model.wSparsity < 1.0:
        config.model.sparsify = True
    else:
        config.model.sparsify = False
|
||||
if args.gate_nonlinearity:
|
||||
config.model.gate_nonlinearity = args.gate_nonlinearity
|
||||
if args.update_nonlinearity:
|
||||
config.model.update_nonlinearity = args.update_nonlinearity
|
||||
|
||||
if not os.path.isfile("config.json"):
|
||||
config.save("config.json")
|
||||
|
||||
train(config, args.eval, args.outdir, args.detail, azureml)
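# Hypothetical invocation, for illustration only (the flags are defined above,
# the values and paths are made up):
#
#   python train_classifier.py --architecture FastGRNN --num_layers 2 \
#       --hidden_units1 128 --hidden_units2 64 --epochs 30 \
#       --dataset ./audio --categories categories.txt --outdir ./out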
@ -0,0 +1,146 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

# config file for train_classifier.py

import json
import os


class ModelOptions:
    def __init__(self):
        self.architecture = "FastGRNN"
        self.num_layers = 1
        self.hidden_units = None
        self.hidden_units1 = None
        self.hidden_units2 = None
        self.hidden_units3 = None
        self.filename = ""
        self.wRank = None
        self.uRank = None
        self.wRank1 = None
        self.uRank1 = None
        self.wRank2 = None
        self.uRank2 = None
        self.wRank3 = None
        self.uRank3 = None
        self.gate_nonlinearity = "sigmoid"
        self.update_nonlinearity = "tanh"


class DatasetOptions:
    def __init__(self):
        self.name = "speechcommandsv01"
        self.featurizer = "featurizer_mel_16000_512_512_80_40_log"
        self.categories = "categories.txt"
        self.path = ""
        self.auto_scale = False
        self.normalize = False


class OptimizerOptions:
    def __init__(self):
        self.weight_decay = 1e-5
        self.momentum = 0.9  # RMSprop
        self.centered = False  # RMSprop
        self.alpha = 0  # ASGD, RMSprop
        self.eps = 1e-8
        self.rho = 0  # Adadelta
        self.lr_decay = 0  # Adagrad
        self.betas = (0.9, 0.999)  # Adam, SparseAdam, Adamax
        self.lambd = 0.0001  # ASGD
        self.t0 = 1000000.0  # ASGD
        self.etas = (0.5, 1.2)  # Rprop
        self.dampening = 0  # SGD
        self.step_sizes = (1e-06, 50)  # Rprop
        self.nesterov = True  # SGD


class TrainingOptions:
    def __init__(self):
        self.max_epochs = 30
        self.learning_rate = 1e-2
        self.lr_scheduler = None
        self.lr_peaks = 1
        self.lr_min = 1e-5
        self.lr_gamma = 1
        self.lr_step_size = 1
        self.batch_size = 128
        self.optimizer = "SGD"
        self.optimizer_options = OptimizerOptions()
        self.use_gpu = False
        self.rolling = False
        self.max_rolling_length = 100
        self.decay_step = 200
        self.decay_rate = 0.1


class TrainingConfig:
    def __init__(self):
        self.name = ""
        self.description = ""
        self.folder = None

        self.model = ModelOptions()
        self.dataset = DatasetOptions()
        self.training = TrainingOptions()

        self.job_id = None
        self.status = None
        self.downloaded = False
        self.last_modified = 0
        self.retries = 0
        self.filename = None
        self.sweep = None

    def set(self, name, value):
        if name not in self.__dict__:
            self.__dict__[name] = value
        else:
            # coerce the new value to the type of the existing value;
            # fixed: this was a chain of independent `if` statements whose
            # trailing `else` clobbered the int/float conversions
            t = self.__dict__[name]
            if type(t) == int:
                self.__dict__[name] = int(value)
            elif type(t) == float:
                self.__dict__[name] = float(value)
            elif type(t) == str:
                self.__dict__[name] = str(value)
            else:
                self.__dict__[name] = value

    def load(self, filename):
        with open(filename, "r") as f:
            data = json.load(f)
        TrainingConfig.from_dict(self, data)
        self.filename = filename
        self.last_modified = os.path.getmtime(self.filename)

    def save(self, filename):
        """ Save the configuration as JSON to the given filename. """
        self.filename = filename
        data = TrainingConfig.to_dict(self)
        data["model"] = self.model.__dict__
        with open(filename, "w") as f:
            json.dump(data, f, indent=2, sort_keys=True)
        self.last_modified = os.path.getmtime(self.filename)

    @staticmethod
    def to_dict(obj):
        data = dict(obj.__dict__)
        for k in data:
            o = data[k]
            if hasattr(o, "__dict__"):
                data[k] = TrainingConfig.to_dict(o)
        return data

    @staticmethod
    def from_dict(obj, data):
        for k in data:
            v = data[k]
            if not hasattr(obj, k):
                setattr(obj, k, v)
            else:
                if isinstance(v, dict):
                    TrainingConfig.from_dict(getattr(obj, k), v)
                elif isinstance(getattr(obj, k), tuple):
                    # JSON stores tuples as lists; convert back on load
                    setattr(obj, k, tuple(v))
                else:
                    setattr(obj, k, v)
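# Usage sketch (not part of this file): the configuration round-trips through
# JSON, so a typical session looks like
#
#   config = TrainingConfig()
#   config.load("config.json")         # merge values from an existing file
#   config.training.max_epochs = 50    # override programmatically
#   config.save("config.json")         # persist; also refreshes last_modified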
@ -0,0 +1,93 @@
# EdgeML FastCells on a sample public dataset

This directory includes example notebooks and scripts for training
FastCells (FastRNN & FastGRNN), along with modified
UGRNN, GRU and LSTM cells that support the LSQ training routine.
There is also a sample cleanup and train/test script for the USPS10 public dataset.
The subfolder [`KWS-training`](KWS-training) contains code
for training a keyword spotting model using a single- or multi-layer RNN.

[`edgeml_pytorch.graph.rnn`](../../../pytorch/edgeml_pytorch/graph/rnn.py)
provides two RNN cells, **FastRNNCell** and **FastGRNNCell**, with additional
features like low-rank parameterisation and custom non-linearities. Akin to
Bonsai and ProtoNN, the three-phase training routine for FastRNN and FastGRNN
is decoupled from the custom cells to facilitate plug and play behaviour of
the custom RNN cells in other architectures (NMT, Encoder-Decoder etc.).
Additionally, numerically equivalent CUDA-based implementations FastRNNCuda
and FastGRNNCuda are provided for faster training.
`edgeml_pytorch.graph.rnn` also contains modified RNN cells **UGRNNCell**,
**GRUCell**, and **LSTMCell**, which can be substituted for Fast(G)RNN,
as well as unrolled RNNs which are equivalent to `nn.LSTM` and `nn.GRU`.

Note that all the cells and wrappers, when used independently of `fastcell_example.py`
or `edgeml_pytorch.trainer.fastTrainer`, take in data in a batch-first format, i.e.,
[batchSize, timeSteps, inputDims] by default, but can also support the [timeSteps,
batchSize, inputDims] format if the `batch_first` argument is set to False.
`fast_example.py` automatically adjusts to the correct format across tf, c++ and pytorch.
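As a minimal illustration of the two layouts (plain PyTorch tensor ops; the sizes below are made up):

```python
import torch

# batch-first layout: [batchSize, timeSteps, inputDims]
x_batch_first = torch.randn(32, 99, 40)

# time-first layout: [timeSteps, batchSize, inputDims]; this is the layout
# the wrappers expect when batch_first is set to False
x_time_first = x_batch_first.transpose(0, 1)
print(x_time_first.shape)  # torch.Size([99, 32, 40])
```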

For training FastCells, `edgeml_pytorch.trainer.fastTrainer` implements the three-phase
FastCell training routine in PyTorch. A simple example, `fastcell_example.py`, is provided
to illustrate its usage. Note that `fastcell_example.py` assumes that data is in a specific format.
It is assumed that train and test data is contained in two files, `train.npy` and
`test.npy`, each containing a 2D numpy array of dimension `[numberOfExamples,
numberOfFeatures]`. numberOfFeatures is `timesteps x inputDims`, flattened
across the timestep dimension, with the input of the first time step followed by the second
and so on. For an N-class problem, we assume the labels are integers from 0
through N-1. Lastly, the training data `train.npy` is assumed to be well shuffled,
as the training routine doesn't shuffle internally.
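A short sketch of producing `train.npy` in this layout; the raw array here is
random stand-in data, shaped like the USPS example below (16 timesteps of 16 features):

```python
import numpy as np

# stand-in for real series of shape [numberOfExamples, timesteps, inputDims]
raw = np.random.randn(1000, 16, 16).astype(np.float32)

# flatten across the timestep dimension: each row is [t0 inputs, t1 inputs, ...]
flat = raw.reshape(raw.shape[0], -1)   # shape [1000, 256]
np.save('train.npy', flat)             # the same recipe produces test.npy
```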

**Tested With:** PyTorch 1.1 with Python 3.6

## Download and clean up sample dataset

To validate the code with the USPS dataset, first download and format the dataset to match
the required format using the scripts [fetch_usps.py](fetch_usps.py) and
[process_usps.py](process_usps.py):

```
python fetch_usps.py
python process_usps.py
```

Note: Even though USPS10 is not a time-series dataset, it can be regarded as a time-series
dataset in which each time step sees a new row of the image. So the number of timesteps = 16 and inputDims = 16.

## Sample command for FastCells on USPS10
The following is a sample run on USPS10:

```bash
python fastcell_example.py -dir usps10/ -id 16 -hd 32
```
This command should give you a final output that reads roughly similar to the following
(the exact numbers may vary across library versions):

```
Maximum Test accuracy at compressed model size(including early stopping): 0.9407075 at Epoch: 262
Final Test Accuracy: 0.93721974

Non-Zeros: 1932 Model Size: 7.546875 KB hasSparse: False
```
The `usps10/` directory will now have a consolidated results file called `FastRNNResults.txt` or
`FastGRNNResults.txt`, depending on the choice of RNN cell. A directory `FastRNNResults` or
`FastGRNNResults` is also created, containing the corresponding models from each run of the code on the `usps10` dataset.

Note that scalars like `alpha`, `beta`, `zeta` and `nu` correspond to the values before
the application of the sigmoid function.

## Byte Quantization (Q) for model compression
If you wish to quantize the generated model, use `quantizeFastModels.py`. Usage instructions:

```
python quantizeFastModels.py -h
```

This will generate quantized models, with a `q` prefix on every stored parameter, in a
new directory `QuantizedFastModel` inside the model directory.

Note that scalars like `qalpha`, `qbeta`, `qzeta` and `qnu` correspond to values
after the application of the sigmoid function over them post quantization;
they can be directly plugged into the inference pipelines.

Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT license.
@ -5,8 +5,8 @@ import helpermethods
import torch
import numpy as np
import sys
from pytorch_edgeml.graph.rnn import *
from pytorch_edgeml.trainer.fastTrainer import FastTrainer
from edgeml_pytorch.graph.rnn import *
from edgeml_pytorch.trainer.fastTrainer import FastTrainer


def main():

@ -53,24 +53,24 @@ def main():

    if cell == "FastGRNN":
        FastCell = FastGRNNCell(inputDims, hiddenDims,
                                gate_non_linearity=gate_non_linearity,
                                update_non_linearity=update_non_linearity,
                                gate_nonlinearity=gate_non_linearity,
                                update_nonlinearity=update_non_linearity,
                                wRank=wRank, uRank=uRank)
    elif cell == "FastRNN":
        FastCell = FastRNNCell(inputDims, hiddenDims,
                               update_non_linearity=update_non_linearity,
                               update_nonlinearity=update_non_linearity,
                               wRank=wRank, uRank=uRank)
    elif cell == "UGRNN":
        FastCell = UGRNNLRCell(inputDims, hiddenDims,
                               update_non_linearity=update_non_linearity,
                               update_nonlinearity=update_non_linearity,
                               wRank=wRank, uRank=uRank)
    elif cell == "GRU":
        FastCell = GRULRCell(inputDims, hiddenDims,
                             update_non_linearity=update_non_linearity,
                             update_nonlinearity=update_non_linearity,
                             wRank=wRank, uRank=uRank)
    elif cell == "LSTM":
        FastCell = LSTMLRCell(inputDims, hiddenDims,
                              update_non_linearity=update_non_linearity,
                              update_nonlinearity=update_non_linearity,
                              wRank=wRank, uRank=uRank)
    else:
        sys.exit('Exiting: No Such Cell as ' + cell)
@ -4,11 +4,11 @@ This directory includes an example [notebook](protoNN_example.ipynb) and a
command line execution script of ProtoNN developed as part of EdgeML. The
example is based on the USPS dataset.

`pytorch_edgeml.graph.protoNN` implements the ProtoNN prediction functions.
`edgeml_pytorch.graph.protoNN` implements the ProtoNN prediction functions.
The training routine for ProtoNN is decoupled from the forward graph to
facilitate a plug and play behaviour wherein ProtoNN can be combined with or
used as a final layer classifier for other architectures (RNNs, CNNs). The
training routine is implemented in `pytorch_edgeml.trainer.protoNNTrainer`.
training routine is implemented in `edgeml_pytorch.trainer.protoNNTrainer`.
(This is also an artifact of consistency requirements with the TensorFlow
implementation.)
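A minimal construction sketch is below; the argument names follow the ProtoNN
formulation (input dimension, projection dimension, number of prototypes,
number of output labels, RBF width `gamma`) but should be checked against the
class definition, and the hyperparameter values are made up:

```python
import torch
from edgeml_pytorch.graph.protoNN import ProtoNN

# hypothetical hyperparameters, for illustration only
protoNN = ProtoNN(inputDimension=256, projectionDimension=10,
                  numPrototypes=20, numOutputLabels=10, gamma=0.0015)

x = torch.randn(32, 256)   # a batch of flattened examples
scores = protoNN(x)        # per-label similarity scores for each example
```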
@ -5,7 +5,7 @@ from __future__ import print_function
import sys
import os
import numpy as np
import pytorch_edgeml.utils as utils
import edgeml_pytorch.utils as utils
import argparse


@ -22,7 +22,7 @@ def getModelSize(matrixList, sparcityList, expected=True, bytesPerVar=4):
    assert A.ndim == 2
    assert s >= 0
    assert s <= 1
    nnz, size, sparse = utils.countnnZ(A, s, bytesPerVar=bytesPerVar)
    nnz, size, sparse = utils.estimateNNZ(A, s, bytesPerVar=bytesPerVar)
    nnzList.append(nnz)
    sizeList.append(size)
    hasSparse = (hasSparse or sparse)
@ -17,9 +17,9 @@
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from pytorch_edgeml.graph.protoNN import ProtoNN\n",
    "from pytorch_edgeml.trainer.protoNNTrainer import ProtoNNTrainer\n",
    "import pytorch_edgeml.utils as utils\n",
    "from edgeml_pytorch.graph.protoNN import ProtoNN\n",
    "from edgeml_pytorch.trainer.protoNNTrainer import ProtoNNTrainer\n",
    "import edgeml_pytorch.utils as utils\n",
    "import helpermethods as helper"
   ]
  },
@ -5,9 +5,9 @@ from __future__ import print_function
import sys
import os
import numpy as np
from pytorch_edgeml.trainer.protoNNTrainer import ProtoNNTrainer
from pytorch_edgeml.graph.protoNN import ProtoNN
import pytorch_edgeml.utils as utils
from edgeml_pytorch.trainer.protoNNTrainer import ProtoNNTrainer
from edgeml_pytorch.graph.protoNN import ProtoNN
import edgeml_pytorch.utils as utils
import helpermethods as helper
import torch

@ -5,9 +5,9 @@ This directory includes an example [notebook](SRNN_Example.ipynb) and a
training a simple model on the [Google Speech Commands
Dataset](https://ai.googleblog.com/2017/08/launching-speech-commands-dataset.html).

`pytorch_edgeml.graph.rnn.SRNN2` implements a 2-layer SRNN network. We will use
`edgeml_pytorch.graph.rnn.SRNN2` implements a 2-layer SRNN network. We will use
this with an LSTM cell on this dataset. The training routine for SRNN is
implemented in `pytorch_edgeml.trainer.srnnTrainer` and will be used as part of
implemented in `edgeml_pytorch.trainer.srnnTrainer` and will be used as part of
this example.
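A minimal construction sketch, assuming `SRNN2` takes the input feature
dimension, the number of classes, the two hidden sizes and the cell type (the
values are made up; check the class definition for the exact signature):

```python
from edgeml_pytorch.graph.rnn import SRNN2

# 32 input features per step, 12 output classes, and hidden sizes
# 64 and 32 for the lower and upper tiers of the SRNN
srnn2 = SRNN2(32, 12, 64, 32, cellType='LSTM')
```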

**Tested With:** PyTorch > 1.1.0 with Python 2 and Python 3
@ -39,9 +39,9 @@
  },
  "outputs": [],
  "source": [
    "from pytorch_edgeml.graph.rnn import SRNN2\n",
    "from pytorch_edgeml.trainer.srnnTrainer import SRNNTrainer\n",
    "import pytorch_edgeml.utils as utils"
    "from edgeml_pytorch.graph.rnn import SRNN2\n",
    "from edgeml_pytorch.trainer.srnnTrainer import SRNNTrainer\n",
    "import edgeml_pytorch.utils as utils"
   ]
  },
  {
@ -7,9 +7,9 @@ import os
import numpy as np
import torch

from pytorch_edgeml.graph.rnn import SRNN2
from pytorch_edgeml.trainer.srnnTrainer import SRNNTrainer
import pytorch_edgeml.utils as utils
from edgeml_pytorch.graph.rnn import SRNN2
from edgeml_pytorch.trainer.srnnTrainer import SRNNTrainer
import edgeml_pytorch.utils as utils
import helpermethods as helper

config = helper.getSRNN2Args()

@ -30,6 +30,8 @@ std[std[:] < 0.000001] = 1
x_train_ = (x_train_ - mean) / std
x_val_ = (x_val_ - mean) / std
x_test_ = (x_test_ - mean) / std
np.save('mean.npy', mean)
np.save('std.npy', std)

x_train = np.swapaxes(x_train_, 0, 1)
x_val = np.swapaxes(x_val_, 0, 1)

@ -73,3 +75,6 @@ trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy', device=device)

trainer.train(brickSize, batchSize, epochs, x_train, x_val, y_train, y_val,
              printStep=printStep, valStep=valStep)

print('Saving trained model:')
torch.save(srnn2.state_dict(), 'model_srnn.pt')
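# A matching load for later inference would be (standard PyTorch, sketch only):
#   srnn2.load_state_dict(torch.load('model_srnn.pt'))
#   srnn2.eval()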
@ -4,7 +4,7 @@ This directory includes an example notebook and general execution script of
Bonsai developed as part of EdgeML. Also, we include a sample cleanup and
use-case on the USPS10 public dataset.

`edgeml.graph.bonsai` implements the Bonsai prediction graph in tensorflow.
`edgeml_tf.graph.bonsai` implements the Bonsai prediction graph in TensorFlow.
The three-phase training routine for Bonsai is decoupled from the forward graph
to facilitate a plug and play behaviour wherein Bonsai can be combined with or
used as a final layer classifier for other architectures (RNNs, CNNs).
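A minimal construction sketch, assuming the graph takes the number of classes,
the data dimension, the projection dimension, the tree depth and `sigma` (all
values below are made up; check `edgeml_tf.graph.bonsai` for the exact signature):

```python
import tensorflow as tf
from edgeml_tf.graph.bonsai import Bonsai

# project 784-dim inputs to 28 dims and learn a depth-3 tree over 10 classes;
# sigma controls the sharpness of the soft path indicator functions
bonsai = Bonsai(numClasses=10, dataDimension=784,
                projectionDimension=28, treeDepth=3, sigma=1.0)
```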
@ -33,8 +33,8 @@
    "os.environ['CUDA_VISIBLE_DEVICES'] =''\n",
    "\n",
    "#Bonsai imports\n",
    "from edgeml.trainer.bonsaiTrainer import BonsaiTrainer\n",
    "from edgeml.graph.bonsai import Bonsai\n",
    "from edgeml_tf.trainer.bonsaiTrainer import BonsaiTrainer\n",
    "from edgeml_tf.graph.bonsai import Bonsai\n",
    "\n",
    "# Fixing seeds for reproducibility\n",
    "tf.set_random_seed(42)\n",
@ -5,8 +5,8 @@ import helpermethods
import tensorflow as tf
import numpy as np
import sys
from edgeml.trainer.bonsaiTrainer import BonsaiTrainer
from edgeml.graph.bonsai import Bonsai
from edgeml_tf.trainer.bonsaiTrainer import BonsaiTrainer
from edgeml_tf.graph.bonsai import Bonsai


def main():
@ -33,15 +33,14 @@
    "import numpy as np\n",
    "# Making sure edgeml is part of python path\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] ='0'\n",
    "\n",
    "np.random.seed(42)\n",
    "tf.set_random_seed(42)\n",
    "\n",
    "# MI-RNN and EMI-RNN imports\n",
    "from edgeml.graph.rnn import EMI_DataPipeline\n",
    "from edgeml.graph.rnn import EMI_BasicLSTM\n",
    "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
    "import edgeml.utils"
    "from edgeml_tf.graph.rnn import EMI_DataPipeline\n",
    "from edgeml_tf.graph.rnn import EMI_BasicLSTM\n",
    "from edgeml_tf.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
    "import edgeml_tf.utils"
   ]
  },
  {

@ -120,10 +119,10 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
     "x_train shape is: (6294, 6, 48, 9)\n",
     "y_train shape is: (6294, 6, 6)\n",
     "x_test shape is: (1058, 6, 48, 9)\n",
     "y_test shape is: (1058, 6, 6)\n"
     "x_train shape is: (6409, 6, 48, 9)\n",
     "y_train shape is: (6409, 6, 6)\n",
     "x_test shape is: (943, 6, 48, 9)\n",
     "y_test shape is: (943, 6, 6)\n"
    ]
   }
  ],
@ -272,6 +271,11 @@
   }
  },
  "outputs": [
   {
    "name": "stderr",
    "output_type": "stream",
    "text": []
   },
   {
    "name": "stdout",
    "output_type": "stream",

@ -279,36 +283,54 @@
    "Update policy: top-k\n",
    "Training with MI-RNN loss for 3 rounds\n",
    "Round: 0\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.00467 Acc 0.90104 | Val acc 0.90454 | Model saved to /tmp/model-lstm, global_step 1000\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.00262 Acc 0.93750 | Val acc 0.91777 | Model saved to /tmp/model-lstm, global_step 1001\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.00285 Acc 0.90625 | Val acc 0.91871 | Model saved to /tmp/model-lstm, global_step 1002\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.00263 Acc 0.91146 | Val acc 0.92344 | Model saved to /tmp/model-lstm, global_step 1003\n",
    "INFO:tensorflow:Restoring parameters from /tmp/model-lstm-1003\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.00375 Acc 0.95312 | Val acc 0.91304 | Model saved to /tmp/model-lstm, global_step 1000\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.00279 Acc 0.96875 | Val acc 0.91835 | Model saved to /tmp/model-lstm, global_step 1001\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.00262 Acc 0.96875 | Val acc 0.89077 | Model saved to /tmp/model-lstm, global_step 1002\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.00255 Acc 0.96875 | Val acc 0.90668 | "
   ]
  },
  {
   "name": "stderr",
   "output_type": "stream",
   "text": []
  },
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "Model saved to /tmp/model-lstm, global_step 1003\n"
   ]
  },
  {
   "name": "stderr",
   "output_type": "stream",
   "text": []
  },
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "Round: 1\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.00260 Acc 0.91146 | Val acc 0.92155 | Model saved to /tmp/model-lstm, global_step 1004\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.00268 Acc 0.91146 | Val acc 0.92628 | Model saved to /tmp/model-lstm, global_step 1005\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.00241 Acc 0.92188 | Val acc 0.92911 | Model saved to /tmp/model-lstm, global_step 1006\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.00245 Acc 0.91667 | Val acc 0.91493 | Model saved to /tmp/model-lstm, global_step 1007\n",
    "INFO:tensorflow:Restoring parameters from /tmp/model-lstm-1006\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.00262 Acc 0.96875 | Val acc 0.89077 | Model saved to /tmp/model-lstm, global_step 1004\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.00255 Acc 0.96875 | Val acc 0.90668 | Model saved to /tmp/model-lstm, global_step 1005\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.00247 Acc 0.96875 | Val acc 0.89183 | Model saved to /tmp/model-lstm, global_step 1006\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.00253 Acc 0.96875 | Val acc 0.91410 | Model saved to /tmp/model-lstm, global_step 1007\n",
    "Round: 2\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.00245 Acc 0.91667 | Val acc 0.91493 | Model saved to /tmp/model-lstm, global_step 1008\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.00247 Acc 0.93750 | Val acc 0.91210 | Model saved to /tmp/model-lstm, global_step 1009\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.00238 Acc 0.93750 | Val acc 0.91115 | Model saved to /tmp/model-lstm, global_step 1010\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.00247 Acc 0.91667 | Val acc 0.90737 | Model saved to /tmp/model-lstm, global_step 1011\n",
    "INFO:tensorflow:Restoring parameters from /tmp/model-lstm-1008\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.00249 Acc 0.96875 | Val acc 0.90456 | Model saved to /tmp/model-lstm, global_step 1008\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.00239 Acc 0.96875 | Val acc 0.89714 | Model saved to /tmp/model-lstm, global_step 1009\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.00255 Acc 0.96354 | Val acc 0.91516 | Model saved to /tmp/model-lstm, global_step 1010\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.00232 Acc 0.96875 | Val acc 0.91092 | Model saved to /tmp/model-lstm, global_step 1011\n",
    "Round: 3\n",
    "Switching to EMI-Loss function\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.19644 Acc 0.92188 | Val acc 0.91304 | Model saved to /tmp/model-lstm, global_step 1012\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.19590 Acc 0.92188 | Val acc 0.91304 | Model saved to /tmp/model-lstm, global_step 1013\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.18886 Acc 0.91667 | Val acc 0.92250 | Model saved to /tmp/model-lstm, global_step 1014\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.17789 Acc 0.92708 | Val acc 0.91210 | Model saved to /tmp/model-lstm, global_step 1015\n",
    "INFO:tensorflow:Restoring parameters from /tmp/model-lstm-1014\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.23041 Acc 0.96875 | Val acc 0.89608 | Model saved to /tmp/model-lstm, global_step 1012\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.20689 Acc 0.96875 | Val acc 0.89396 | Model saved to /tmp/model-lstm, global_step 1013\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.19695 Acc 0.96875 | Val acc 0.90562 | Model saved to /tmp/model-lstm, global_step 1014\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.18891 Acc 0.96875 | Val acc 0.89608 | Model saved to /tmp/model-lstm, global_step 1015\n",
    "Round: 4\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.17789 Acc 0.92708 | Val acc 0.91210 | Model saved to /tmp/model-lstm, global_step 1016\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.17308 Acc 0.93750 | Val acc 0.90737 | Model saved to /tmp/model-lstm, global_step 1017\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.16609 Acc 0.93750 | Val acc 0.91682 | Model saved to /tmp/model-lstm, global_step 1018\n",
    "Epoch 1 Batch 193 ( 390) Loss 0.16253 Acc 0.93750 | Val acc 0.91115 | Model saved to /tmp/model-lstm, global_step 1019\n",
    "INFO:tensorflow:Restoring parameters from /tmp/model-lstm-1018\n"
    "Epoch 1 Batch 189 ( 390) Loss 0.18891 Acc 0.96875 | Val acc 0.89608 | Model saved to /tmp/model-lstm, global_step 1016\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.17931 Acc 0.96875 | Val acc 0.90456 | Model saved to /tmp/model-lstm, global_step 1017\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.17625 Acc 0.96875 | Val acc 0.90138 | Model saved to /tmp/model-lstm, global_step 1018\n",
    "Epoch 1 Batch 189 ( 390) Loss 0.16728 Acc 0.96875 | Val acc 0.93319 | Model saved to /tmp/model-lstm, global_step 1019\n"
   ]
  }
 ],

@ -394,10 +416,10 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
     "Accuracy at k = 2: 0.894130\n",
     "Accuracy at k = 2: 0.898880\n",
     "Savings due to MI-RNN : 0.625000\n",
     "Savings due to Early prediction: 0.696645\n",
     "Total Savings: 0.886242\n"
     "Savings due to Early prediction: 0.623507\n",
     "Total Savings: 0.858815\n"
    ]
   }
  ],
@ -431,27 +453,27 @@
    "output_type": "stream",
    "text": [
     " len acc macro-fsc macro-pre macro-rec micro-fsc micro-pre \\\n",
     "0 1 0.869019 0.868122 0.878439 0.871618 0.869019 0.869019 \n",
     "1 2 0.894130 0.894577 0.898733 0.896412 0.894130 0.894130 \n",
     "2 3 0.893451 0.893584 0.897599 0.894880 0.893451 0.893451 \n",
     "3 4 0.873770 0.873280 0.882517 0.873642 0.873770 0.873770 \n",
     "4 5 0.853410 0.853243 0.870547 0.851575 0.853410 0.853410 \n",
     "5 6 0.836105 0.836891 0.863634 0.833071 0.836105 0.836105 \n",
     "0 1 0.881235 0.881088 0.887293 0.885055 0.881235 0.881235 \n",
     "1 2 0.898880 0.899793 0.901265 0.902492 0.898880 0.898880 \n",
     "2 3 0.902952 0.904039 0.903778 0.906140 0.902952 0.902952 \n",
     "3 4 0.888700 0.889662 0.892359 0.890378 0.888700 0.888700 \n",
     "4 5 0.873431 0.874679 0.882768 0.874046 0.873431 0.873431 \n",
     "5 6 0.860197 0.862230 0.877873 0.859560 0.860197 0.860197 \n",
     "\n",
     " micro-rec \n",
     "0 0.869019 \n",
     "1 0.894130 \n",
     "2 0.893451 \n",
     "3 0.873770 \n",
     "4 0.853410 \n",
     "5 0.836105 \n",
     "Max accuracy 0.894130 at subsequencelength 2\n",
     "Max micro-f 0.894130 at subsequencelength 2\n",
     "Micro-precision 0.894130 at subsequencelength 2\n",
     "Micro-recall 0.894130 at subsequencelength 2\n",
     "Max macro-f 0.894577 at subsequencelength 2\n",
     "macro-precision 0.898733 at subsequencelength 2\n",
     "macro-recall 0.896412 at subsequencelength 2\n"
     "0 0.881235 \n",
     "1 0.898880 \n",
     "2 0.902952 \n",
     "3 0.888700 \n",
     "4 0.873431 \n",
     "5 0.860197 \n",
     "Max accuracy 0.902952 at subsequencelength 3\n",
     "Max micro-f 0.902952 at subsequencelength 3\n",
     "Micro-precision 0.902952 at subsequencelength 3\n",
     "Micro-recall 0.902952 at subsequencelength 3\n",
     "Max macro-f 0.904039 at subsequencelength 3\n",
     "macro-precision 0.903778 at subsequencelength 3\n",
     "macro-recall 0.906140 at subsequencelength 3\n"
    ]
   }
  ],
@ -483,16 +505,11 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
     "INFO:tensorflow:Restoring parameters from /tmp/model-lstm-1003\n",
     "Round: 0, Validation accuracy: 0.9234, Test Accuracy (k = 2): 0.899559, Total Savings: 0.790902\n",
     "INFO:tensorflow:Restoring parameters from /tmp/model-lstm-1006\n",
     "Round: 1, Validation accuracy: 0.9291, Test Accuracy (k = 2): 0.896844, Total Savings: 0.814705\n",
     "INFO:tensorflow:Restoring parameters from /tmp/model-lstm-1008\n",
     "Round: 2, Validation accuracy: 0.9149, Test Accuracy (k = 2): 0.894469, Total Savings: 0.821671\n",
     "INFO:tensorflow:Restoring parameters from /tmp/model-lstm-1014\n",
     "Round: 3, Validation accuracy: 0.9225, Test Accuracy (k = 2): 0.894130, Total Savings: 0.876447\n",
     "INFO:tensorflow:Restoring parameters from /tmp/model-lstm-1018\n",
     "Round: 4, Validation accuracy: 0.9168, Test Accuracy (k = 2): 0.894130, Total Savings: 0.886242\n"
     "Round: 0, Validation accuracy: 0.9183, Test Accuracy (k = 2): 0.916865, Total Savings: 0.765207\n",
     "Round: 1, Validation accuracy: 0.9141, Test Accuracy (k = 2): 0.915507, Total Savings: 0.789403\n",
     "Round: 2, Validation accuracy: 0.9152, Test Accuracy (k = 2): 0.908381, Total Savings: 0.799538\n",
     "Round: 3, Validation accuracy: 0.9056, Test Accuracy (k = 2): 0.903970, Total Savings: 0.844050\n",
     "Round: 4, Validation accuracy: 0.9332, Test Accuracy (k = 2): 0.898880, Total Savings: 0.858815\n"
    ]
   }
  ],

@ -523,9 +540,9 @@
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "display_name": ".tf",
   "language": "python",
   "name": "python3"
   "name": ".tf"
  },
  "language_info": {
   "codemirror_mode": {

@ -537,7 +554,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
@ -38,10 +38,10 @@
    "tf.set_random_seed(42)\n",
    "\n",
    "# MI-RNN and EMI-RNN imports\n",
    "from edgeml.graph.rnn import EMI_DataPipeline\n",
    "from edgeml.graph.rnn import EMI_BasicLSTM\n",
    "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
    "import edgeml.utils"
    "from edgeml_tf.graph.rnn import EMI_DataPipeline\n",
    "from edgeml_tf.graph.rnn import EMI_BasicLSTM\n",
    "from edgeml_tf.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
    "import edgeml_tf.utils"
   ]
  },
  {

@ -598,7 +598,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
@ -34,11 +34,11 @@
    "os.environ['CUDA_VISIBLE_DEVICES'] ='1'\n",
    "\n",
    "# FastGRNN and FastRNN imports\n",
    "from edgeml.graph.rnn import EMI_DataPipeline\n",
    "from edgeml.graph.rnn import EMI_FastGRNN\n",
    "from edgeml.graph.rnn import EMI_FastRNN\n",
    "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
    "import edgeml.utils"
    "from edgeml_tf.graph.rnn import EMI_DataPipeline\n",
    "from edgeml_tf.graph.rnn import EMI_FastGRNN\n",
    "from edgeml_tf.graph.rnn import EMI_FastRNN\n",
    "from edgeml_tf.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
    "import edgeml_tf.utils"
   ]
  },
  {
@ -35,10 +35,10 @@
    "os.environ['CUDA_VISIBLE_DEVICES'] ='0'\n",
    "\n",
    "# MI-RNN and EMI-RNN imports\n",
    "from edgeml.graph.rnn import EMI_DataPipeline\n",
    "from edgeml.graph.rnn import EMI_BasicLSTM\n",
    "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
    "import edgeml.utils"
    "from edgeml_tf.graph.rnn import EMI_DataPipeline\n",
    "from edgeml_tf.graph.rnn import EMI_BasicLSTM\n",
    "from edgeml_tf.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
    "import edgeml_tf.utils"
   ]
  },
  {

@ -654,7 +654,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
@ -3,7 +3,7 @@
This directory includes example notebooks for EMI-RNN developed as part of EdgeML.
The example is based on the UCI Human Activity Recognition dataset.

Please refer to `tf/docs/EMI-RNN.md` for detailed documentation of EMI-RNN.
Please refer to `docs/EMI-RNN.md` for detailed documentation of EMI-RNN.

Please refer to `00_emi_lstm_example.ipynb` for a quick and dirty getting
started guide.
tf/examples/EMI-RNN/img/3PartsGraph.png → examples/tf/EMI-RNN/img/3PartsGraph.png
Executable file → Normal file
tf/examples/EMI-RNN/img/MIML_illustration.png → examples/tf/EMI-RNN/img/MIML_illustration.png
Executable file → Normal file
@ -5,7 +5,7 @@ FastCells (FastRNN & FastGRNN) developed as part of EdgeML along with modified
UGRNN, GRU and LSTM to support the LSQ training routine.
Also, we include a sample cleanup and use-case on the USPS10 public dataset.

`edgeml.graph.rnn` implements the custom RNN cells of **FastRNN** ([`FastRNNCell`](../../edgeml/graph/rnn.py#L215)) and **FastGRNN** ([`FastGRNNCell`](../../edgeml/graph/rnn.py#L40)) with
`edgeml_tf.graph.rnn` implements the custom RNN cells of **FastRNN** ([`FastRNNCell`](../../edgeml/graph/rnn.py#L215)) and **FastGRNN** ([`FastGRNNCell`](../../edgeml/graph/rnn.py#L40)) with
multiple additional features like low-rank parameterisation, custom
non-linearities etc. Similar to Bonsai and ProtoNN, the three-phase training
routine for FastRNN and FastGRNN is decoupled from the custom cells to
@ -14,9 +14,9 @@ architectures (NMT, Encoder-Decoder etc.,) in place of the inbuilt `RNNCell`, `G
`edgeml.graph.rnn` also contains modified RNN cells of **UGRNN** ([`UGRNNLRCell`](../../edgeml/graph/rnn.py#L862)),
**GRU** ([`GRULRCell`](../../edgeml/graph/rnn.py#L635)) and **LSTM** ([`LSTMLRCell`](../../edgeml/graph/rnn.py#L376)). These cells can also be substituted for FastCells wherever feasible.

For training FastCells, `edgeml.trainer.fastTrainer` implements the three-phase
For training FastCells, `edgeml_tf.trainer.fastTrainer` implements the three-phase
FastCell training routine in TensorFlow. A simple example,
`examples/fastcell_example.py`, is provided to illustrate its usage.
`examples/tf/fastcell_example.py`, is provided to illustrate its usage.

Note that `fastcell_example.py` assumes that data is in a specific format. It
is assumed that train and test data is contained in two files, `train.npy` and
@ -28,12 +28,12 @@
    "os.environ['CUDA_VISIBLE_DEVICES'] =''\n",
    "\n",
    "#FastRNN and FastGRNN imports\n",
    "from edgeml.trainer.fastTrainer import FastTrainer\n",
    "from edgeml.graph.rnn import FastGRNNCell\n",
    "from edgeml.graph.rnn import FastRNNCell\n",
    "from edgeml.graph.rnn import UGRNNLRCell\n",
    "from edgeml.graph.rnn import GRULRCell\n",
    "from edgeml.graph.rnn import LSTMLRCell\n",
    "from edgeml_tf.trainer.fastTrainer import FastTrainer\n",
    "from edgeml_tf.graph.rnn import FastGRNNCell\n",
    "from edgeml_tf.graph.rnn import FastRNNCell\n",
    "from edgeml_tf.graph.rnn import UGRNNLRCell\n",
    "from edgeml_tf.graph.rnn import GRULRCell\n",
    "from edgeml_tf.graph.rnn import LSTMLRCell\n",
    "\n",
    "# Fixing seeds for reproducibility\n",
    "tf.set_random_seed(42)\n",
@ -6,12 +6,12 @@ import tensorflow as tf
import numpy as np
import sys

from edgeml.trainer.fastTrainer import FastTrainer
from edgeml.graph.rnn import FastGRNNCell
from edgeml.graph.rnn import FastRNNCell
from edgeml.graph.rnn import UGRNNLRCell
from edgeml.graph.rnn import GRULRCell
from edgeml.graph.rnn import LSTMLRCell
from edgeml_tf.trainer.fastTrainer import FastTrainer
from edgeml_tf.graph.rnn import FastGRNNCell
from edgeml_tf.graph.rnn import FastRNNCell
from edgeml_tf.graph.rnn import UGRNNLRCell
from edgeml_tf.graph.rnn import GRULRCell
from edgeml_tf.graph.rnn import LSTMLRCell


def main():
@ -4,11 +4,11 @@ This directory includes an example [notebook](protoNN_example.ipynb) and a
command line execution script of ProtoNN developed as part of EdgeML. The
example is based on the USPS dataset.

`edgeml.graph.protoNN` implements the ProtoNN prediction graph in Tensorflow.
`edgeml_tf.graph.protoNN` implements the ProtoNN prediction graph in TensorFlow.
The training routine for ProtoNN is decoupled from the forward graph to
facilitate a plug and play behaviour wherein ProtoNN can be combined with or
used as a final layer classifier for other architectures (RNNs, CNNs). The
training routine is implemented in `edgeml.trainer.protoNNTrainer`.
training routine is implemented in `edgeml_tf.trainer.protoNNTrainer`.

Note that `protoNN_example.py` assumes the data to be in a specific format. It
is assumed that train and test data is contained in two files, `train.npy` and
@ -6,7 +6,7 @@ import sys
import os
import numpy as np
import tensorflow as tf
import edgeml.utils as utils
import edgeml_tf.utils as utils
import argparse


@ -29,9 +29,9 @@
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "from edgeml.trainer.protoNNTrainer import ProtoNNTrainer\n",
    "from edgeml.graph.protoNN import ProtoNN\n",
    "import edgeml.utils as utils\n",
    "from edgeml_tf.trainer.protoNNTrainer import ProtoNNTrainer\n",
    "from edgeml_tf.graph.protoNN import ProtoNN\n",
    "import edgeml_tf.utils as utils\n",
    "import helpermethods as helper"
   ]
  },
@ -6,9 +6,9 @@ import sys
import os
import numpy as np
import tensorflow as tf
from edgeml.trainer.protoNNTrainer import ProtoNNTrainer
from edgeml.graph.protoNN import ProtoNN
import edgeml.utils as utils
from edgeml_tf.trainer.protoNNTrainer import ProtoNNTrainer
from edgeml_tf.graph.protoNN import ProtoNN
import edgeml_tf.utils as utils
import helpermethods as helper


@ -1,27 +1,48 @@
## Edge Machine Learning: PyTorch Library

This directory includes PyTorch implementations of various techniques and
algorithms developed as part of EdgeML. Currently, the following algorithms are
available in PyTorch:
This package includes PyTorch implementations of the following algorithms and training
techniques developed as part of EdgeML. The PyTorch graphs for the forward/backward
pass of these algorithms are packaged as `edgeml_pytorch.graph` and the trainers
for these algorithms are in `edgeml_pytorch.trainer`.

1. [Bonsai](../docs/publications/Bonsai.pdf)
2. [FastRNN & FastGRNN](../docs/publications/FastGRNN.pdf)
1. [Bonsai](https://github.com/microsoft/EdgeML/blob/master/docs/publications/Bonsai.pdf): `edgeml_pytorch.graph.bonsai` implements
the Bonsai prediction graph. The three-phase training routine for Bonsai is decoupled
from the forward graph to facilitate a plug and play behaviour wherein Bonsai can be
combined with or used as a final layer classifier for other architectures (RNNs, CNNs).
See `edgeml_pytorch.trainer.bonsaiTrainer` for 3-phase training.
2. [ProtoNN](https://github.com/microsoft/EdgeML/blob/master/docs/publications/ProtoNN.pdf): `edgeml_pytorch.graph.protoNN` implements the
ProtoNN prediction functions. The training routine for ProtoNN is decoupled from the forward
graph to facilitate a plug and play behaviour wherein ProtoNN can be combined with or used
as a final layer classifier for other architectures (RNNs, CNNs). The training routine is
implemented in `edgeml_pytorch.trainer.protoNNTrainer`.
3. [FastRNN & FastGRNN](https://github.com/microsoft/EdgeML/blob/master/docs/publications/FastGRNN.pdf): `edgeml_pytorch.graph.rnn` provides
various RNN cells --- including new cells `FastRNNCell` and `FastGRNNCell` as well as
`UGRNNCell`, `GRUCell`, and `LSTMCell` --- with features like low-rank parameterisation
of weight matrices and custom non-linearities. Akin to Bonsai and ProtoNN, the three-phase
training routine for FastRNN and FastGRNN is decoupled from the custom cells to enable plug and
play behaviour of the custom RNN cells in other architectures (NMT, Encoder-Decoder etc.).
Additionally, numerically equivalent CUDA-based implementations `FastRNNCUDACell` and
`FastGRNNCUDACell` are provided for faster training.
`edgeml_pytorch.graph.rnn.Fast(G)RNN(CUDA)` provides unrolled RNNs equivalent to `nn.LSTM` and `nn.GRU`.
`edgeml_pytorch.trainer.fastmodel` presents a sample multi-layer RNN + multi-class classifier model.
A construction sketch for these cells follows this list.
4. [S-RNN](https://github.com/microsoft/EdgeML/blob/master/docs/publications/SRNN.pdf): `edgeml_pytorch.graph.rnn.SRNN2` implements a
2-layer SRNN network which can be instantiated with a choice of RNN cell. The training
routine for SRNN is in `edgeml_pytorch.trainer.srnnTrainer`.
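A construction sketch for the Fast cells, reusing the `FastGRNNCell` signature
that appears in `fastcell_example.py` in this repository (the dimensions and
ranks are illustrative, and the single-step call assumes the standard PyTorch
`RNNCell` convention):

```python
import torch
from edgeml_pytorch.graph.rnn import FastGRNNCell

# 32-dim inputs, 64-dim hidden state, rank-8 factorizations of W and U
cell = FastGRNNCell(32, 64, gate_nonlinearity="sigmoid",
                    update_nonlinearity="tanh", wRank=8, uRank=8)

x_t = torch.randn(16, 32)   # one timestep for a batch of 16
h_t = torch.zeros(16, 64)   # initial hidden state
h_t = cell(x_t, h_t)        # assumed single-step RNNCell-style update
```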

Usage directions and example notebooks for this package are provided [here](https://github.com/microsoft/EdgeML/blob/master/examples/pytorch).

The PyTorch compute graphs for these algorithms are packaged as
`pytorch_edgeml.graph`. Trainers for these algorithms are in `pytorch_edgeml.trainer`. Usage
directions and examples for these algorithms are provided in the `examples`
directory. To get started with any of the provided algorithms, please follow
the notebooks in the `examples` directory.

## Installation

Use pip and the provided requirements file to first install required
dependencies before installing the `pytorch_edgeml` library. Details for installation are provided below.

It is highly recommended that EdgeML be installed in a virtual environment. Please create
a new virtual environment using your environment manager ([virtualenv](https://virtualenv.pypa.io/en/stable/userguide/#usage) or [Anaconda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-with-commands)).
It is highly recommended that EdgeML be installed in a virtual environment.
Please create a new virtual environment using your environment manager
([virtualenv](https://virtualenv.pypa.io/en/stable/userguide/#usage) or
[Anaconda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-with-commands)).
Make sure the new environment is active before running the commands mentioned below.

Use pip to install requirements before installing the `edgeml_pytorch` library.
Details for CPU-based installation and GPU-based installation are provided below.

### CPU

```

@ -29,18 +50,16 @@ pip install -r requirements-cpu.txt
pip install -e .
```

Tested on Python 3.6 with PyTorch 1.1.
Tested on Python 3.6 with PyTorch >= 1.1.0.

### GPU

Install appropriate CUDA and cuDNN [Tested with >= CUDA 9.0 and cuDNN >= 7.0]
Install appropriate CUDA and cuDNN [Tested with >= CUDA 8.1 and cuDNN >= 6.1]

```
pip install -r requirements-gpu.txt
pip install -e .
```

Note: If the above commands don't go through for PyTorch installation on CPU and GPU, please follow this [link](https://pytorch.org/get-started/locally/).

Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT license.