Mirror of https://github.com/microsoft/EdgeML.git
Merge RNNPool Codes (#201)
* added bidirectional * bidirectional in BaseRNN * updated all rnn for new bidirectional * debugging for fastgrnncuda * fix for fastgrnncuda * visual wakeword * visual wakewords evaluation * visual wakeword evaluation * updated readme for eval * face detection * update face detection * rnn edit * test update * model loading change * update readme * update eval tools in readme * update train * readme * readme * readme * requirements * requirements * readme * rnn * update s3fd_net * update s3fd_net * update s3fd_net * train * train * add additional args * add additional args * arg changes in wider_test * added arg for using new ckpt * readme * remove old modelsupport * readme * readme * requirements * readme * eval on all format * support for calculating MAP * update readme * Update README.md * readme * readme * readme * readme * fix for warnings * readme * readme scores * add dump weights and traces support * readme * remove eval warnings * eval remove import warnings * readme changes * readme changes * support for qvga monochrome * readme update * readme update * Update README.md * readme update * environment key update * config files * update both config files text * change architecture * readme update * quantized cpp rnnpool * Update README.md * Update README.md * Update README.md * smaller model for qvga * Update RPool_Face_QVGA_monochrome.py * update to ssd code * update to init * update to init * update to dataloader * tf code for face detection * tf code for face detection * add tf face detection code * eval file * fix weights and detect function * Delete factory.py * Update RPool_Face_QVGA_monochrome.py * Update RPool_Face_QVGA_monochrome.py * Update RPool_Face_C.py * Update RPool_Face_Quant.py * Update augmentations.py * Update detection.py * Update eval.py * vww updates * removed tf code * Update fastcell_example.py * Update widerface.py * Update rnnpool.py * Update fastTrainer.py * Update fastTrainer.py * Update fastTrainer.py * Update model_mobilenet_2rnnpool.py * Update model_mobilenet_rnnpool.py * Delete top_level.txt * Delete dependency_links.txt * remove egg-info * Update fastcell_example.py * file copyright edits * Update README.md * delete output blank file * remove input trace file Co-authored-by: Harsha Vardhan Simhadri <harsha-simhadri@users.noreply.github.com> Co-authored-by: Ubuntu <harshasi@GPUnode1.n14uw44gbsdu3cvfvcchcfucod.xx.internal.cloudapp.net> Co-authored-by: Harsha Vardhan Simhadri <harshasi@microsoft.com>
This commit is contained in:
Parent
2c89850fff
Commit
373a7a14ee
@ -0,0 +1,146 @@
# Code for Face Detection experiments with RNNPool

## Requirements

1. Follow the instructions [here](https://github.com/microsoft/EdgeML/blob/master/pytorch/README.md) to install the EdgeML PyTorch library and its operators.

2. Install the requirements for the face detection model using

``` pip install -r requirements.txt ```

We have tested the installation and the code on Ubuntu 18.04 with CUDA 10.2 and cuDNN 7.6.

## Dataset

1. Download the WIDER face dataset images and annotations from http://shuoyang1213.me/WIDERFACE/ and place them all in a folder named 'WIDER_FACE'. That is, download WIDER_train.zip, WIDER_test.zip, WIDER_val.zip and wider_face_split.zip, place them in the WIDER_FACE folder, and unzip them using:

```shell
cd WIDER_FACE
unzip WIDER_train.zip
unzip WIDER_test.zip
unzip WIDER_val.zip
unzip wider_face_split.zip
cd ..
```

2. In `data/config.py`, set _C.HOME to the parent directory of the above folder and set _C.FACE.WIDER_DIR to the folder path. That is, if the WIDER_FACE folder was created in the /mnt folder, then _C.HOME='/mnt' and _C.FACE.WIDER_DIR='/mnt/WIDER_FACE'. Similarly, change `data/config_qvga.py` to set _C.HOME and _C.FACE.WIDER_DIR (see the snippet after this list).

3. Run

``` python prepare_wider_data.py ```
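
For example, if WIDER_FACE was unzipped under /mnt as above, the relevant lines in `data/config.py` (and the same two lines in `data/config_qvga.py`) would read:

```python
# data/config.py -- dataset paths (mirror the same change in data/config_qvga.py)
_C.HOME = '/mnt'
_C.FACE.WIDER_DIR = '/mnt/WIDER_FACE'
```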

# Usage

## Training

```shell
IS_QVGA_MONO=0 python train.py --batch_size 32 --model_arch RPool_Face_Quant --cuda True --multigpu True --save_folder weights/ --epochs 300 --save_frequency 5000
```

For QVGA:

```shell
IS_QVGA_MONO=1 python train.py --batch_size 64 --model_arch RPool_Face_QVGA_monochrome --cuda True --multigpu True --save_folder weights/ --epochs 300 --save_frequency 5000
```

This saves a checkpoint every '--save_frequency' iterations in a weight file ending in 'checkpoint.pth', and the weights for the best state so far in a file ending in 'best_state.pth'; both are written to '--save_folder'. To resume training from a checkpoint, add '--resume <checkpoint_name>.pth' to the above command. For example,

```shell
IS_QVGA_MONO=1 python train.py --batch_size 64 --model_arch RPool_Face_QVGA_monochrome --cuda True --multigpu True --save_folder weights/ --epochs 300 --save_frequency 5000 --resume <checkpoint_name>.pth
```

If IS_QVGA_MONO is 0, training input images are 640x640 RGB.

If IS_QVGA_MONO is 1, training input images are 320x320 and converted to monochrome.

Input images for training are cropped and reshaped to be square, to maintain consistency with [S3FD](https://arxiv.org/abs/1708.05237). Testing, however, can be done on images of any size: the test input is resized so that its area equals that of VGA (640x480) or QVGA (320x240), which leaves the aspect ratio unchanged.
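
The shrink factor used for this resizing boils down to the following computation (a minimal sketch of what the evaluation scripts do, assuming OpenCV for the resize):

```python
import cv2
import numpy as np

def resize_to_vga_area(img):
    # Scale so the resized area equals 640*480 (use 320*240 for QVGA) while keeping the aspect ratio.
    shrink = np.sqrt(640 * 480 / (img.shape[0] * img.shape[1]))
    return cv2.resize(img, None, fx=shrink, fy=shrink, interpolation=cv2.INTER_LINEAR)
```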

The architecture RPool_Face_QVGA_monochrome is for the QVGA monochrome format, while RPool_Face_C and RPool_Face_Quant are for the VGA RGB format.


## Test

There are two modes of testing the trained model: an evaluation mode that generates bounding boxes for a set of sample images, and a test mode that computes statistics like mAP scores.

#### Evaluation Mode

Given a set of images in <your_image_folder>, `eval.py` generates bounding boxes around faces (wherever the confidence is higher than a certain threshold) and writes the annotated images to <your_save_folder>. To evaluate a trained model such as `RPool_Face_Quant_best_state.pth` (stored in ./weights), execute the following command:

```shell
IS_QVGA_MONO=0 python eval.py --model_arch RPool_Face_Quant --model ./weights/RPool_Face_Quant_best_state.pth --image_folder <your_image_folder> --save_dir <your_save_folder>
```

For QVGA:

```shell
IS_QVGA_MONO=1 python eval.py --model_arch RPool_Face_QVGA_monochrome --model ./weights/RPool_Face_QVGA_monochrome_best_state.pth --image_folder <your_image_folder> --save_dir <your_save_folder>
```

This saves images in <your_save_folder> with bounding boxes drawn around confidently detected faces. Here is an example image with a single bounding box.

![Camera: Himax0360](imrgb20ft.png)

If IS_QVGA_MONO=0, the evaluation code accepts an image of any size and resizes it to an area of 640x480x3 while preserving the original aspect ratio.

If IS_QVGA_MONO=1, the evaluation code accepts an image of any size, resizes it, and converts it to monochrome to make it 320x240x1, again preserving the original aspect ratio.

#### WIDER Set Test

In this mode, we test the generated model against the provided WIDER_FACE validation and test datasets.

First, run the following to generate the model's predictions and store the output in the '--save_folder' folder:

```shell
IS_QVGA_MONO=0 python wider_test.py --model_arch RPool_Face_Quant --model ./weights/RPool_Face_Quant_best_state.pth --save_folder rpool_face_quant_val --subset val
```

For QVGA:

```shell
IS_QVGA_MONO=1 python wider_test.py --model_arch RPool_Face_QVGA_monochrome --model ./weights/RPool_Face_QVGA_monochrome_best_state.pth --save_folder rpool_face_qvgamono_val --subset val
```

The above command generates predictions for each image in the validation dataset, writing a separate prediction file (an image_name.txt file in the appropriate folder) per image. The first line of the prediction file contains the total number of boxes identified. Each subsequent line corresponds to one identified box and contains five numbers: the width of the box, the height of the box, the x-axis offset, the y-axis offset, and the confidence value for the presence of a face in the box.
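
A minimal sketch of parsing one of these prediction files, assuming exactly the layout described above (the helper name is illustrative and not part of the repository):

```python
def read_wider_predictions(txt_path):
    """Return a list of (width, height, x_offset, y_offset, confidence) tuples."""
    with open(txt_path) as f:
        num_boxes = int(f.readline().strip())
        return [tuple(map(float, f.readline().split())) for _ in range(num_boxes)]
```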

If IS_QVGA_MONO=1, testing is done by converting images to monochrome and QVGA; if IS_QVGA_MONO=0, testing is done on VGA RGB images.

The architecture RPool_Face_QVGA_monochrome is for the QVGA monochrome format, while RPool_Face_C and RPool_Face_Quant are for the VGA RGB format.

###### For calculating mAP scores:

Using these boxes, we can compute the standard mAP score that is widely used in this literature (see [here](https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173) for more details) as follows:

1. Download eval_tools.zip from http://shuoyang1213.me/WIDERFACE/support/eval_script/eval_tools.zip and unzip it into a folder of the same name in this directory.

Example code:

```shell
wget http://shuoyang1213.me/WIDERFACE/support/eval_script/eval_tools.zip
unzip eval_tools.zip
```

2. Set up scripts that use the Matlab '.mat' data files in the eval_tools/ground_truth folder for the mAP calculation. The following installs Python files that provide the same functionality as the '.m' Matlab scripts in the eval_tools folder.

```
cd eval_tools
git clone https://github.com/wondervictor/WiderFace-Evaluation.git
cd WiderFace-Evaluation
python3 setup.py build_ext --inplace
```

3. Run ```python3 evaluation.py -p <your_save_folder> -g <ground truth dir>``` in the WiderFace-Evaluation folder,

where <your_save_folder> is the '--save_folder' used for `wider_test.py` above and <ground truth dir> is the subfolder `eval_tools/ground_truth`. That is, in the WiderFace-Evaluation directory, run:

```shell
python3 evaluation.py -p <your_save_folder> -g ../ground_truth
```

This script outputs the mAP for the WIDER-easy, WIDER-medium, and WIDER-hard subsets of the dataset. Our best performance using the RPool_Face_Quant model is: 0.80 (WIDER-easy), 0.78 (WIDER-medium), 0.53 (WIDER-hard).


##### Dump RNNPool Input-Output Traces and Weights

To save model weights and/or the input-output pairs for each patch passing through RNNPool in numpy format, use the command below. Put the images you want to save traces for in <your_image_folder>. Specify the output folder for the model weights in numpy format in <your_save_model_numpy_folder>, and the output folder for the RNNPool input-output traces in numpy format in <your_save_traces_numpy_folder>. Note that input traces are saved in a folder named 'inputs' and output traces in a folder named 'outputs' inside <your_save_traces_numpy_folder>.

```shell
python3 dump_model.py --model ./weights/RPool_Face_QVGA_monochrome_best_state.pth --model_arch RPool_Face_Quant --image_folder <your_image_folder> --save_model_npy_dir <your_save_model_numpy_folder> --save_traces_npy_dir <your_save_traces_numpy_folder>
```

If you wish to save only model weights, do not specify --save_traces_npy_dir. If you wish to save only traces do not specify --save_model_npy_dir.
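
As a quick sanity check, the saved arrays can be loaded back with numpy; a minimal sketch (the 'inputs'/'outputs' layout follows the description above):

```python
import os
import numpy as np

traces_dir = '<your_save_traces_numpy_folder>'  # folder passed to --save_traces_npy_dir
sample = sorted(os.listdir(os.path.join(traces_dir, 'inputs')))[0]
patch_in = np.load(os.path.join(traces_dir, 'inputs', sample))    # one RNNPool input patch
patch_out = np.load(os.path.join(traces_dir, 'outputs', sample))  # corresponding RNNPool output
print(patch_in.shape, patch_out.shape)
```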

This code builds upon https://github.com/yxlijun/S3FD.pytorch
|
|
@ -0,0 +1,31 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

from .widerface import WIDERDetection

from data.choose_config import cfg
cfg = cfg.cfg

import torch


def detection_collate(batch):
    """Custom collate fn for dealing with batches of images that have a different
    number of associated object annotations (bounding boxes).

    Arguments:
        batch: (tuple) A tuple of tensor images and lists of annotations

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) annotations for a given image are stacked on
               0 dim
    """
    targets = []
    imgs = []
    for sample in batch:
        imgs.append(sample[0])
        targets.append(torch.FloatTensor(sample[1]))
    return torch.stack(imgs, 0), targets
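

# Minimal usage sketch (illustrative, not part of the original file): detection_collate is
# meant to be passed as the collate_fn of a torch DataLoader, since each image carries a
# different number of ground-truth boxes. Assumes face_train.txt was already generated by
# prepare_wider_data.py.
#
#   from torch.utils.data import DataLoader
#   from data import WIDERDetection, detection_collate, cfg
#
#   dataset = WIDERDetection(cfg.FACE.TRAIN_FILE, mode='train')
#   loader = DataLoader(dataset, batch_size=32, shuffle=True,
#                       collate_fn=detection_collate, num_workers=4)
#   images, targets = next(iter(loader))  # images: stacked tensor; targets: list of box tensors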
|
|
@ -0,0 +1,15 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import os
from importlib import import_module

IS_QVGA_MONO = os.environ['IS_QVGA_MONO']

name = 'config'
if IS_QVGA_MONO == '1':
    name = name + '_qvga'

cfg = import_module('data.' + name)
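
# Usage sketch (illustrative): downstream scripts import the selected config like this, which
# is the same pattern used in eval.py and data/__init__.py. IS_QVGA_MONO is normally set on
# the command line, e.g. `IS_QVGA_MONO=1 python train.py ...`.
#
#   from data.choose_config import cfg
#   cfg = cfg.cfg              # unwrap the module to get the EasyDict of settings
#   print(cfg.INPUT_SIZE)      # 640 for the VGA RGB config, 320 for the QVGA monochrome config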
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
from easydict import EasyDict
|
||||
import numpy as np
|
||||
|
||||
|
||||
_C = EasyDict()
|
||||
cfg = _C
|
||||
# data augmentation config
|
||||
_C.expand_prob = 0.5
|
||||
_C.expand_max_ratio = 4
|
||||
_C.hue_prob = 0.5
|
||||
_C.hue_delta = 18
|
||||
_C.contrast_prob = 0.5
|
||||
_C.contrast_delta = 0.5
|
||||
_C.saturation_prob = 0.5
|
||||
_C.saturation_delta = 0.5
|
||||
_C.brightness_prob = 0.5
|
||||
_C.brightness_delta = 0.125
|
||||
_C.data_anchor_sampling_prob = 0.5
|
||||
_C.min_face_size = 6.0
|
||||
_C.apply_distort = True
|
||||
_C.apply_expand = False
|
||||
_C.img_mean = np.array([104., 117., 123.])[:, np.newaxis, np.newaxis].astype(
|
||||
'float32')
|
||||
_C.resize_width = 640
|
||||
_C.resize_height = 640
|
||||
_C.scale = 1 / 127.0
|
||||
_C.anchor_sampling = True
|
||||
_C.filter_min_face = True
|
||||
|
||||
|
||||
_C.IS_MONOCHROME = False
|
||||
|
||||
|
||||
# anchor config
|
||||
_C.FEATURE_MAPS = [160, 80, 40, 20, 10, 5]
|
||||
_C.INPUT_SIZE = 640
|
||||
_C.STEPS = [4, 8, 16, 32, 64, 128]
|
||||
_C.ANCHOR_SIZES = [16, 32, 64, 128, 256, 512]
|
||||
_C.CLIP = False
|
||||
_C.VARIANCE = [0.1, 0.2]
|
||||
|
||||
# detection config
|
||||
_C.NMS_THRESH = 0.3
|
||||
_C.NMS_TOP_K = 5000
|
||||
_C.TOP_K = 750
|
||||
_C.CONF_THRESH = 0.01
|
||||
|
||||
# loss config
|
||||
_C.NEG_POS_RATIOS = 3
|
||||
_C.NUM_CLASSES = 2
|
||||
_C.USE_NMS = True
|
||||
|
||||
# dataset config
|
||||
_C.HOME = '/mnt/'  # change this to the parent directory of the WIDER_FACE folder
|
||||
|
||||
# face config
|
||||
_C.FACE = EasyDict()
|
||||
_C.FACE.TRAIN_FILE = './data/face_train.txt'
|
||||
_C.FACE.VAL_FILE = './data/face_val.txt'
|
||||
_C.FACE.WIDER_DIR = '/mnt/WIDER_FACE'  # change this to the WIDER_FACE folder path
|
||||
_C.FACE.OVERLAP_THRESH = [0.1, 0.35, 0.5]
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
from easydict import EasyDict
|
||||
import numpy as np
|
||||
|
||||
|
||||
_C = EasyDict()
|
||||
cfg = _C
|
||||
# data augmentation config
|
||||
_C.expand_prob = 0.5
|
||||
_C.expand_max_ratio = 2
|
||||
_C.hue_prob = 0.5
|
||||
_C.hue_delta = 18
|
||||
_C.contrast_prob = 0.5
|
||||
_C.contrast_delta = 0.5
|
||||
_C.saturation_prob = 0.5
|
||||
_C.saturation_delta = 0.5
|
||||
_C.brightness_prob = 0.5
|
||||
_C.brightness_delta = 0.125
|
||||
_C.data_anchor_sampling_prob = 0.5
|
||||
_C.min_face_size = 1.0
|
||||
_C.apply_distort = True
|
||||
_C.apply_expand = False
|
||||
_C.img_mean = np.array([104., 117., 123.])[:, np.newaxis, np.newaxis].astype(
|
||||
'float32')
|
||||
_C.resize_width = 320
|
||||
_C.resize_height = 320
|
||||
_C.scale = 1 / 127.0
|
||||
_C.anchor_sampling = True
|
||||
_C.filter_min_face = True
|
||||
|
||||
|
||||
_C.IS_MONOCHROME = True
|
||||
|
||||
# anchor config
|
||||
_C.FEATURE_MAPS = [40, 40, 20, 20]
|
||||
_C.INPUT_SIZE = 320
|
||||
_C.STEPS = [8, 8, 16, 16]
|
||||
_C.ANCHOR_SIZES = [8, 16, 32, 48]
|
||||
_C.CLIP = False
|
||||
_C.VARIANCE = [0.1, 0.2]
|
||||
|
||||
# detection config
|
||||
_C.NMS_THRESH = 0.3
|
||||
_C.NMS_TOP_K = 5000
|
||||
_C.TOP_K = 750
|
||||
_C.CONF_THRESH = 0.05
|
||||
|
||||
# loss config
|
||||
_C.NEG_POS_RATIOS = 3
|
||||
_C.NUM_CLASSES = 2
|
||||
_C.USE_NMS = True
|
||||
|
||||
# dataset config
|
||||
_C.HOME = '/mnt/'
|
||||
|
||||
# face config
|
||||
_C.FACE = EasyDict()
|
||||
_C.FACE.TRAIN_FILE = './data/face_train.txt'
|
||||
_C.FACE.VAL_FILE = './data/face_val.txt'
|
||||
_C.FACE.WIDER_DIR = '/mnt/WIDER_FACE'
|
||||
_C.FACE.OVERLAP_THRESH = [0.1, 0.35, 0.5]
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
from PIL import Image, ImageDraw
|
||||
import torch.utils.data as data
|
||||
import numpy as np
|
||||
import random
|
||||
import sys; sys.path.append('../')
|
||||
from utils.augmentations import preprocess
|
||||
|
||||
|
||||
class WIDERDetection(data.Dataset):
|
||||
"""docstring for WIDERDetection"""
|
||||
|
||||
def __init__(self, list_file, mode='train', mono_mode=False):
|
||||
super(WIDERDetection, self).__init__()
|
||||
self.mode = mode
|
||||
self.mono_mode = mono_mode
|
||||
self.fnames = []
|
||||
self.boxes = []
|
||||
self.labels = []
|
||||
|
||||
with open(list_file) as f:
|
||||
lines = f.readlines()
|
||||
|
||||
for line in lines:
|
||||
line = line.strip().split()
|
||||
num_faces = int(line[1])
|
||||
box = []
|
||||
label = []
|
||||
for i in range(num_faces):
|
||||
x = float(line[2 + 5 * i])
|
||||
y = float(line[3 + 5 * i])
|
||||
w = float(line[4 + 5 * i])
|
||||
h = float(line[5 + 5 * i])
|
||||
c = int(line[6 + 5 * i])
|
||||
if w <= 0 or h <= 0:
|
||||
continue
|
||||
box.append([x, y, x + w, y + h])
|
||||
label.append(c)
|
||||
if len(box) > 0:
|
||||
self.fnames.append(line[0])
|
||||
self.boxes.append(box)
|
||||
self.labels.append(label)
|
||||
|
||||
self.num_samples = len(self.boxes)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, index):
|
||||
img, target, h, w = self.pull_item(index)
|
||||
return img, target
|
||||
|
||||
def pull_item(self, index):
|
||||
while True:
|
||||
image_path = self.fnames[index]
|
||||
img = Image.open(image_path)
|
||||
if img.mode == 'L':
|
||||
img = img.convert('RGB')
|
||||
|
||||
im_width, im_height = img.size
|
||||
boxes = self.annotransform(
|
||||
np.array(self.boxes[index]), im_width, im_height)
|
||||
label = np.array(self.labels[index])
|
||||
bbox_labels = np.hstack((label[:, np.newaxis], boxes)).tolist()
|
||||
img, sample_labels = preprocess(
|
||||
img, bbox_labels, self.mode, image_path)
|
||||
sample_labels = np.array(sample_labels)
|
||||
if len(sample_labels) > 0:
|
||||
target = np.hstack(
|
||||
(sample_labels[:, 1:], sample_labels[:, 0][:, np.newaxis]))
|
||||
|
||||
assert (target[:, 2] > target[:, 0]).any()
|
||||
assert (target[:, 3] > target[:, 1]).any()
|
||||
break
|
||||
else:
|
||||
index = random.randrange(0, self.num_samples)
|
||||
|
||||
|
||||
if self.mono_mode==True:
|
||||
im = 0.299 * img[0] + 0.587 * img[1] + 0.114 * img[2]
|
||||
return torch.from_numpy(np.expand_dims(im,axis=0)), target, im_height, im_width
|
||||
|
||||
return torch.from_numpy(img), target, im_height, im_width
|
||||
|
||||
|
||||
def annotransform(self, boxes, im_width, im_height):
|
||||
boxes[:, 0] /= im_width
|
||||
boxes[:, 1] /= im_height
|
||||
boxes[:, 2] /= im_width
|
||||
boxes[:, 3] /= im_height
|
||||
return boxes
|
||||
|
||||
|
||||
def detection_collate(batch):
|
||||
"""Custom collate fn for dealing with batches of images that have a different
|
||||
number of associated object annotations (bounding boxes).
|
||||
|
||||
Arguments:
|
||||
batch: (tuple) A tuple of tensor images and lists of annotations
|
||||
|
||||
Return:
|
||||
A tuple containing:
|
||||
1) (tensor) batch of images stacked on their 0 dim
|
||||
2) (list of tensors) annotations for a given image are stacked on
|
||||
0 dim
|
||||
"""
|
||||
targets = []
|
||||
imgs = []
|
||||
for sample in batch:
|
||||
imgs.append(sample[0])
|
||||
targets.append(torch.FloatTensor(sample[1]))
|
||||
return torch.stack(imgs, 0), targets
|
|
@ -0,0 +1,191 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import torch
|
||||
import argparse
|
||||
import torch.nn as nn
|
||||
import torch.utils.data as data
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
import cv2
|
||||
import time
|
||||
import numpy as np
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
from data.config import cfg
|
||||
from torch.autograd import Variable
|
||||
from utils.augmentations import to_chw_bgr
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='face detection dump')
|
||||
|
||||
parser.add_argument('--model', type=str,
|
||||
default='weights/rpool_face_c.pth', help='trained model')
|
||||
#small_fgrnn_smallram_sd.pth', help='trained model')
|
||||
parser.add_argument('--model_arch',
|
||||
default='RPool_Face_C', type=str,
|
||||
choices=['RPool_Face_C', 'RPool_Face_B', 'RPool_Face_A', 'RPool_Face_Quant'],
|
||||
help='choose architecture among rpool variants')
|
||||
parser.add_argument('--image_folder', default=None, type=str, help='folder containing images')
|
||||
parser.add_argument('--save_model_npy_dir', default=None, type=str, help='Directory for saving model in numpy array format')
|
||||
parser.add_argument('--save_traces_npy_dir', default=None, type=str, help='Directory for saving RNNPool input and output traces in numpy array format')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
use_cuda = torch.cuda.is_available()
|
||||
|
||||
if use_cuda:
|
||||
torch.set_default_tensor_type('torch.cuda.FloatTensor')
|
||||
else:
|
||||
torch.set_default_tensor_type('torch.FloatTensor')
|
||||
|
||||
|
||||
|
||||
def saveModelNpy(net):
|
||||
if os.path.isdir(args.save_model_npy_dir) is False:
|
||||
try:
|
||||
os.mkdir(args.save_model_npy_dir)
|
||||
except OSError:
|
||||
print("Creation of the directory %s failed" % args.save_model_npy_dir)
|
||||
return
|
||||
|
||||
np.save(args.save_model_npy_dir+'/W1.npy', net.rnn_model.cell_rnn.cell.W.cpu().detach().numpy())
|
||||
np.save(args.save_model_npy_dir+'/W2.npy', net.rnn_model.cell_bidirrnn.cell.W.cpu().detach().numpy())
|
||||
np.save(args.save_model_npy_dir+'/U1.npy', net.rnn_model.cell_rnn.cell.U.cpu().detach().numpy())
|
||||
np.save(args.save_model_npy_dir+'/U2.npy', net.rnn_model.cell_bidirrnn.cell.U.cpu().detach().numpy())
|
||||
np.save(args.save_model_npy_dir+'/Bg1.npy', net.rnn_model.cell_rnn.cell.bias_gate.cpu().detach().numpy())
|
||||
np.save(args.save_model_npy_dir+'/Bg2.npy', net.rnn_model.cell_bidirrnn.cell.bias_gate.cpu().detach().numpy())
|
||||
np.save(args.save_model_npy_dir+'/Bh1.npy', net.rnn_model.cell_rnn.cell.bias_update.cpu().detach().numpy())
|
||||
np.save(args.save_model_npy_dir+'/Bh2.npy', net.rnn_model.cell_bidirrnn.cell.bias_update.cpu().detach().numpy())
|
||||
np.save(args.save_model_npy_dir+'/nu1.npy', net.rnn_model.cell_rnn.cell.nu.cpu().detach().numpy())
|
||||
np.save(args.save_model_npy_dir+'/nu2.npy', net.rnn_model.cell_bidirrnn.cell.nu.cpu().detach().numpy())
|
||||
np.save(args.save_model_npy_dir+'/zeta1.npy', net.rnn_model.cell_rnn.cell.zeta.cpu().detach().numpy())
|
||||
np.save(args.save_model_npy_dir+'/zeta2.npy', net.rnn_model.cell_bidirrnn.cell.zeta.cpu().detach().numpy())
|
||||
|
||||
|
||||
|
||||
activation = {}
|
||||
def get_activation(name):
|
||||
def hook(model, input, output):
|
||||
activation[name] = output.detach()
|
||||
return hook
|
||||
|
||||
def saveTracesNpy(net, img_list):
|
||||
if os.path.isdir(args.save_traces_npy_dir) is False:
|
||||
try:
|
||||
os.mkdir(args.save_traces_npy_dir)
|
||||
except OSError:
|
||||
print("Creation of the directory %s failed" % args.save_traces_npy_dir)
|
||||
return
|
||||
|
||||
if os.path.isdir(os.path.join(args.save_traces_npy_dir,'inputs')) is False:
|
||||
try:
|
||||
os.mkdir(os.path.join(args.save_traces_npy_dir,'inputs'))
|
||||
except OSError:
|
||||
print("Creation of the directory %s failed" % os.path.join(args.save_traces_npy_dir,'inputs'))
|
||||
return
|
||||
|
||||
if os.path.isdir(os.path.join(args.save_traces_npy_dir,'outputs')) is False:
|
||||
try:
|
||||
os.mkdir(os.path.join(args.save_traces_npy_dir,'outputs'))
|
||||
except OSError:
|
||||
print("Creation of the directory %s failed" % os.path.join(args.save_traces_npy_dir,'outputs'))
|
||||
return
|
||||
|
||||
inputDims = net.rnn_model.inputDims
|
||||
nRows = net.rnn_model.nRows
|
||||
nCols = net.rnn_model.nCols
|
||||
count=0
|
||||
for img_path in img_list:
|
||||
img = Image.open(os.path.join(args.image_folder, img_path))
|
||||
|
||||
img = img.convert('RGB')
|
||||
|
||||
img = np.array(img)
|
||||
max_im_shrink = np.sqrt(
|
||||
640 * 480 / (img.shape[0] * img.shape[1]))
|
||||
image = cv2.resize(img, None, None, fx=max_im_shrink,
|
||||
fy=max_im_shrink, interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
x = to_chw_bgr(image)
|
||||
x = x.astype('float32')
|
||||
x -= cfg.img_mean
|
||||
x = x[[2, 1, 0], :, :]
|
||||
|
||||
x = Variable(torch.from_numpy(x).unsqueeze(0))
|
||||
if use_cuda:
|
||||
x = x.cuda()
|
||||
t1 = time.time()
|
||||
y = net(x)
|
||||
|
||||
|
||||
patches = activation['prepatch']
|
||||
patches = torch.cat(torch.unbind(patches,dim=2),dim=0)
|
||||
patches = torch.reshape(patches,(-1,inputDims,nRows,nCols))
|
||||
|
||||
rnnX = activation['rnn_model']
|
||||
|
||||
patches_all = torch.stack(torch.split(patches, split_size_or_sections=1, dim=0),dim=-1)
|
||||
rnnX_all = torch.stack(torch.split(rnnX, split_size_or_sections=1, dim=0),dim=-1)
|
||||
|
||||
for k in range(patches_all.shape[-1]):
|
||||
patches_tosave = patches_all[0,:,:,:,k].cpu().numpy().transpose(1,2,0)
|
||||
rnnX_tosave = rnnX_all[0,:,k].cpu().numpy()
|
||||
np.save(args.save_traces_npy_dir+'/inputs/trace_'+str(count)+'_'+str(k)+'.npy', patches_tosave)
|
||||
np.save(args.save_traces_npy_dir+'/outputs/trace_'+str(count)+'_'+str(k)+'.npy', rnnX_tosave)
|
||||
|
||||
count+=1
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
module = import_module('models.' + args.model_arch)
|
||||
net = module.build_s3fd('test', cfg.NUM_CLASSES)
|
||||
|
||||
# net = torch.nn.DataParallel(net)
|
||||
|
||||
checkpoint_dict = torch.load(args.model)
|
||||
|
||||
model_dict = net.state_dict()
|
||||
|
||||
|
||||
model_dict.update(checkpoint_dict)
|
||||
net.load_state_dict(model_dict)
|
||||
|
||||
|
||||
|
||||
net.eval()
|
||||
|
||||
if use_cuda:
|
||||
net.cuda()
|
||||
cudnn.benchmark = True
|
||||
|
||||
|
||||
|
||||
if args.save_model_npy_dir is not None:
|
||||
saveModelNpy(net)
|
||||
|
||||
if args.save_traces_npy_dir is not None:
|
||||
net.unfold.register_forward_hook(get_activation('prepatch'))
|
||||
net.rnn_model.register_forward_hook(get_activation('rnn_model'))
|
||||
img_path = args.image_folder
|
||||
img_list = [os.path.join(img_path, x)
|
||||
for x in os.listdir(img_path)]
|
||||
saveTracesNpy(net, img_list)
|
|
@ -0,0 +1,133 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.data as data
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import cv2
|
||||
|
||||
from data.choose_config import cfg
|
||||
cfg = cfg.cfg
|
||||
|
||||
from utils.augmentations import to_chw_bgr
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='face detection demo')
|
||||
parser.add_argument('--save_dir', type=str, default='results/',
|
||||
help='Directory for detect result')
|
||||
parser.add_argument('--model', type=str,
|
||||
default='weights/rpool_face_c.pth', help='trained model')
|
||||
parser.add_argument('--thresh', default=0.17, type=float,
|
||||
help='Final confidence threshold')
|
||||
parser.add_argument('--model_arch',
|
||||
default='RPool_Face_C', type=str,
|
||||
choices=['RPool_Face_C', 'RPool_Face_Quant', 'RPool_Face_QVGA_monochrome'],
|
||||
help='choose architecture among rpool variants')
|
||||
parser.add_argument('--image_folder', default=None, type=str, help='folder containing images')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.save_dir):
|
||||
os.makedirs(args.save_dir)
|
||||
|
||||
use_cuda = torch.cuda.is_available()
|
||||
|
||||
if use_cuda:
|
||||
torch.set_default_tensor_type('torch.cuda.FloatTensor')
|
||||
else:
|
||||
torch.set_default_tensor_type('torch.FloatTensor')
|
||||
|
||||
|
||||
def detect(net, img_path, thresh):
|
||||
img = Image.open(img_path)
|
||||
img = img.convert('RGB')
|
||||
img = np.array(img)
|
||||
height, width, _ = img.shape
|
||||
|
||||
if os.environ['IS_QVGA_MONO'] == '1':
|
||||
max_im_shrink = np.sqrt(
|
||||
320 * 240 / (img.shape[0] * img.shape[1]))
|
||||
else:
|
||||
max_im_shrink = np.sqrt(
|
||||
640 * 480 / (img.shape[0] * img.shape[1]))
|
||||
|
||||
image = cv2.resize(img, None, None, fx=max_im_shrink,
|
||||
fy=max_im_shrink, interpolation=cv2.INTER_LINEAR)
|
||||
# img = cv2.resize(img, (640, 640))
|
||||
x = to_chw_bgr(image)
|
||||
x = x.astype('float32')
|
||||
x -= cfg.img_mean
|
||||
x = x[[2, 1, 0], :, :]
|
||||
|
||||
|
||||
if cfg.IS_MONOCHROME == True:
|
||||
x = 0.299 * x[0] + 0.587 * x[1] + 0.114 * x[2]
|
||||
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)
|
||||
else:
|
||||
x = torch.from_numpy(x).unsqueeze(0)
|
||||
if use_cuda:
|
||||
x = x.cuda()
|
||||
t1 = time.time()
|
||||
y = net(x)
|
||||
detections = y.data
|
||||
scale = torch.Tensor([img.shape[1], img.shape[0],
|
||||
img.shape[1], img.shape[0]])
|
||||
|
||||
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
|
||||
|
||||
for i in range(detections.size(1)):
|
||||
j = 0
|
||||
while detections[0, i, j, 0] >= thresh:
|
||||
score = detections[0, i, j, 0]
|
||||
pt = (detections[0, i, j, 1:] * scale).cpu().numpy()
|
||||
left_up, right_bottom = (pt[0], pt[1]), (pt[2], pt[3])
|
||||
j += 1
|
||||
cv2.rectangle(img, left_up, right_bottom, (0, 0, 255), 2)
|
||||
conf = "{:.3f}".format(score)
|
||||
point = (int(left_up[0]), int(left_up[1] - 5))
|
||||
cv2.putText(img, conf, point, cv2.FONT_HERSHEY_COMPLEX,
|
||||
0.6, (0, 255, 0), 1)
|
||||
|
||||
t2 = time.time()
|
||||
print('detect:{} timer:{}'.format(img_path, t2 - t1))
|
||||
|
||||
cv2.imwrite(os.path.join(args.save_dir, os.path.basename(img_path)), img)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
module = import_module('models.' + args.model_arch)
|
||||
net = module.build_s3fd('test', cfg.NUM_CLASSES)
|
||||
|
||||
net = torch.nn.DataParallel(net)
|
||||
|
||||
checkpoint_dict = torch.load(args.model)
|
||||
|
||||
model_dict = net.state_dict()
|
||||
|
||||
|
||||
model_dict.update(checkpoint_dict)
|
||||
net.load_state_dict(model_dict)
|
||||
|
||||
net.eval()
|
||||
|
||||
if use_cuda:
|
||||
net.cuda()
|
||||
cudnn.benchmark = True
|
||||
|
||||
img_path = args.image_folder
|
||||
img_list = [os.path.join(img_path, x)
|
||||
for x in os.listdir(img_path)]
|
||||
for path in img_list:
|
||||
detect(net, path, args.thresh)
|
Binary file not shown.
After: Size: 390 KiB
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .functions import *
|
||||
from .modules import *
|
|
@ -0,0 +1,306 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def point_form(boxes):
|
||||
""" Convert prior_boxes to (xmin, ymin, xmax, ymax)
|
||||
representation for comparison to point form ground truth data.
|
||||
Args:
|
||||
boxes: (tensor) center-size default boxes from priorbox layers.
|
||||
Return:
|
||||
boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
|
||||
"""
|
||||
return torch.cat((boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin
|
||||
boxes[:, :2] + boxes[:, 2:] / 2), 1) # xmax, ymax
|
||||
|
||||
|
||||
def center_size(boxes):
|
||||
""" Convert prior_boxes to (cx, cy, w, h)
|
||||
representation for comparison to center-size form ground truth data.
|
||||
Args:
|
||||
boxes: (tensor) point_form boxes
|
||||
Return:
|
||||
boxes: (tensor) Converted (cx, cy, w, h) form of boxes.
|
||||
"""
|
||||
return torch.cat([(boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy
|
||||
boxes[:, 2:] - boxes[:, :2]], 1) # w, h
|
||||
|
||||
|
||||
def intersect(box_a, box_b):
|
||||
""" We resize both tensors to [A,B,2] without new malloc:
|
||||
[A,2] -> [A,1,2] -> [A,B,2]
|
||||
[B,2] -> [1,B,2] -> [A,B,2]
|
||||
Then we compute the area of intersect between box_a and box_b.
|
||||
Args:
|
||||
box_a: (tensor) bounding boxes, Shape: [A,4].
|
||||
box_b: (tensor) bounding boxes, Shape: [B,4].
|
||||
Return:
|
||||
(tensor) intersection area, Shape: [A,B].
|
||||
"""
|
||||
A = box_a.size(0)
|
||||
B = box_b.size(0)
|
||||
max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
|
||||
box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
|
||||
min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
|
||||
box_b[:, :2].unsqueeze(0).expand(A, B, 2))
|
||||
inter = torch.clamp((max_xy - min_xy), min=0)
|
||||
return inter[:, :, 0] * inter[:, :, 1]
|
||||
|
||||
|
||||
def jaccard(box_a, box_b):
|
||||
"""Compute the jaccard overlap of two sets of boxes. The jaccard overlap
|
||||
is simply the intersection over union of two boxes. Here we operate on
|
||||
ground truth boxes and default boxes.
|
||||
E.g.:
|
||||
A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
|
||||
Args:
|
||||
box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
|
||||
box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
|
||||
Return:
|
||||
jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
|
||||
"""
|
||||
inter = intersect(box_a, box_b)
|
||||
area_a = ((box_a[:, 2] - box_a[:, 0]) *
|
||||
(box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
|
||||
area_b = ((box_b[:, 2] - box_b[:, 0]) *
|
||||
(box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
|
||||
union = area_a + area_b - inter
|
||||
return inter / union # [A,B]
|
||||
|
||||
|
||||
def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
|
||||
"""Match each prior box with the ground truth box of the highest jaccard
|
||||
overlap, encode the bounding boxes, then return the matched indices
|
||||
corresponding to both confidence and location preds.
|
||||
Args:
|
||||
threshold: (float) The overlap threshold used when matching boxes.
|
||||
truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors].
|
||||
priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
|
||||
variances: (tensor) Variances corresponding to each prior coord,
|
||||
Shape: [num_priors, 4].
|
||||
labels: (tensor) All the class labels for the image, Shape: [num_obj].
|
||||
loc_t: (tensor) Tensor to be filled w/ encoded location targets.
|
||||
conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
|
||||
idx: (int) current batch index
|
||||
Return:
|
||||
The matched indices corresponding to 1)location and 2)confidence preds.
|
||||
"""
|
||||
# jaccard index
|
||||
overlaps = jaccard(
|
||||
truths,
|
||||
point_form(priors)
|
||||
)
|
||||
# (Bipartite Matching)
|
||||
# [1,num_objects] best prior for each ground truth
|
||||
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
|
||||
# [1,num_priors] best ground truth for each prior
|
||||
best_truth_overlap, best_truth_idx = overlaps.max(
|
||||
0, keepdim=True) # 0-2000
|
||||
best_truth_idx.squeeze_(0)
|
||||
best_truth_overlap.squeeze_(0)
|
||||
best_prior_idx.squeeze_(1)
|
||||
best_prior_overlap.squeeze_(1)
|
||||
best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior
|
||||
# TODO refactor: index best_prior_idx with long tensor
|
||||
# ensure every gt matches with its prior of max overlap
|
||||
for j in range(best_prior_idx.size(0)):
|
||||
best_truth_idx[best_prior_idx[j]] = j
|
||||
_th1, _th2, _th3 = threshold # _th1 = 0.1 ,_th2 = 0.35,_th3 = 0.5
|
||||
|
||||
N = (torch.sum(best_prior_overlap >= _th2) +
|
||||
torch.sum(best_prior_overlap >= _th3)) // 2
|
||||
matches = truths[best_truth_idx] # Shape: [num_priors,4]
|
||||
conf = labels[best_truth_idx] # Shape: [num_priors]
|
||||
conf[best_truth_overlap < _th2] = 0 # label as background
|
||||
|
||||
best_truth_overlap_clone = best_truth_overlap.clone()
|
||||
add_idx = best_truth_overlap_clone.gt(
|
||||
_th1).eq(best_truth_overlap_clone.lt(_th2))
|
||||
best_truth_overlap_clone[~add_idx] = 0
|
||||
stage2_overlap, stage2_idx = best_truth_overlap_clone.sort(descending=True)
|
||||
|
||||
stage2_overlap = stage2_overlap.gt(_th1)
|
||||
|
||||
if N > 0:
|
||||
N = torch.sum(stage2_overlap[:N]) if torch.sum(
|
||||
stage2_overlap[:N]) < N else N
|
||||
conf[stage2_idx[:N]] += 1
|
||||
|
||||
loc = encode(matches, priors, variances)
|
||||
loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
|
||||
conf_t[idx] = conf # [num_priors] top class label for each prior
|
||||
|
||||
|
||||
def match_ssd(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
|
||||
"""Match each prior box with the ground truth box of the highest jaccard
|
||||
overlap, encode the bounding boxes, then return the matched indices
|
||||
corresponding to both confidence and location preds.
|
||||
Args:
|
||||
threshold: (float) The overlap threshold used when matching boxes.
|
||||
truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors].
|
||||
priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
|
||||
variances: (tensor) Variances corresponding to each prior coord,
|
||||
Shape: [num_priors, 4].
|
||||
labels: (tensor) All the class labels for the image, Shape: [num_obj].
|
||||
loc_t: (tensor) Tensor to be filled w/ encoded location targets.
|
||||
conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
|
||||
idx: (int) current batch index
|
||||
Return:
|
||||
The matched indices corresponding to 1)location and 2)confidence preds.
|
||||
"""
|
||||
# jaccard index
|
||||
overlaps = jaccard(
|
||||
truths,
|
||||
point_form(priors)
|
||||
)
|
||||
# (Bipartite Matching)
|
||||
# [1,num_objects] best prior for each ground truth
|
||||
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
|
||||
# [1,num_priors] best ground truth for each prior
|
||||
best_truth_overlap, best_truth_idx = overlaps.max(
|
||||
0, keepdim=True) # 0-2000
|
||||
best_truth_idx.squeeze_(0)
|
||||
best_truth_overlap.squeeze_(0)
|
||||
best_prior_idx.squeeze_(1)
|
||||
best_prior_overlap.squeeze_(1)
|
||||
best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior
|
||||
# TODO refactor: index best_prior_idx with long tensor
|
||||
# ensure every gt matches with its prior of max overlap
|
||||
for j in range(best_prior_idx.size(0)):
|
||||
best_truth_idx[best_prior_idx[j]] = j
|
||||
matches = truths[best_truth_idx] # Shape: [num_priors,4]
|
||||
conf = labels[best_truth_idx] # Shape: [num_priors]
|
||||
conf[best_truth_overlap < threshold] = 0 # label as background
|
||||
loc = encode(matches, priors, variances)
|
||||
loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
|
||||
conf_t[idx] = conf # [num_priors] top class label for each prior
|
||||
|
||||
|
||||
def encode(matched, priors, variances):
|
||||
"""Encode the variances from the priorbox layers into the ground truth boxes
|
||||
we have matched (based on jaccard overlap) with the prior boxes.
|
||||
Args:
|
||||
matched: (tensor) Coords of ground truth for each prior in point-form
|
||||
Shape: [num_priors, 4].
|
||||
priors: (tensor) Prior boxes in center-offset form
|
||||
Shape: [num_priors,4].
|
||||
variances: (list[float]) Variances of priorboxes
|
||||
Return:
|
||||
encoded boxes (tensor), Shape: [num_priors, 4]
|
||||
"""
|
||||
|
||||
# dist b/t match center and prior's center
|
||||
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
|
||||
# encode variance
|
||||
g_cxcy /= (variances[0] * priors[:, 2:])
|
||||
# match wh / prior wh
|
||||
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
|
||||
#g_wh = torch.log(g_wh) / variances[1]
|
||||
g_wh = torch.log(g_wh) / variances[1]
|
||||
# return target for smooth_l1_loss
|
||||
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
|
||||
|
||||
|
||||
# Adapted from https://github.com/Hakuyume/chainer-ssd
|
||||
def decode(loc, priors, variances):
|
||||
"""Decode locations from predictions using priors to undo
|
||||
the encoding we did for offset regression at train time.
|
||||
Args:
|
||||
loc (tensor): location predictions for loc layers,
|
||||
Shape: [num_priors,4]
|
||||
priors (tensor): Prior boxes in center-offset form.
|
||||
Shape: [num_priors,4].
|
||||
variances: (list[float]) Variances of priorboxes
|
||||
Return:
|
||||
decoded bounding box predictions
|
||||
"""
|
||||
|
||||
boxes = torch.cat((
|
||||
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
||||
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
|
||||
boxes[:, :2] -= boxes[:, 2:] / 2
|
||||
boxes[:, 2:] += boxes[:, :2]
|
||||
return boxes
|
||||
|
||||
|
||||
def log_sum_exp(x):
|
||||
"""Utility function for computing log_sum_exp while determining
|
||||
This will be used to determine unaveraged confidence loss across
|
||||
all examples in a batch.
|
||||
Args:
|
||||
x (Variable(tensor)): conf_preds from conf layers
|
||||
"""
|
||||
x_max = x.data.max()
|
||||
return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max
|
||||
|
||||
|
||||
# Original author: Francisco Massa:
|
||||
# https://github.com/fmassa/object-detection.torch
|
||||
# Ported to PyTorch by Max deGroot (02/01/2017)
|
||||
def nms(boxes, scores, overlap=0.5, top_k=200):
|
||||
"""Apply non-maximum suppression at test time to avoid detecting too many
|
||||
overlapping bounding boxes for a given object.
|
||||
Args:
|
||||
boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
|
||||
scores: (tensor) The class predscores for the img, Shape:[num_priors].
|
||||
overlap: (float) The overlap thresh for suppressing unnecessary boxes.
|
||||
top_k: (int) The Maximum number of box preds to consider.
|
||||
Return:
|
||||
The indices of the kept boxes with respect to num_priors.
|
||||
"""
|
||||
|
||||
keep = scores.new(scores.size(0)).zero_().long()
|
||||
if boxes.numel() == 0:
|
||||
return keep
|
||||
x1 = boxes[:, 0]
|
||||
y1 = boxes[:, 1]
|
||||
x2 = boxes[:, 2]
|
||||
y2 = boxes[:, 3]
|
||||
area = torch.mul(x2 - x1, y2 - y1)
|
||||
v, idx = scores.sort(0) # sort in ascending order
|
||||
# I = I[v >= 0.01]
|
||||
idx = idx[-top_k:] # indices of the top-k largest vals
|
||||
xx1 = boxes.new()
|
||||
yy1 = boxes.new()
|
||||
xx2 = boxes.new()
|
||||
yy2 = boxes.new()
|
||||
w = boxes.new()
|
||||
h = boxes.new()
|
||||
|
||||
# keep = torch.Tensor()
|
||||
count = 0
|
||||
while idx.numel() > 0:
|
||||
i = idx[-1] # index of current largest val
|
||||
# keep.append(i)
|
||||
keep[count] = i
|
||||
count += 1
|
||||
if idx.size(0) == 1:
|
||||
break
|
||||
idx = idx[:-1] # remove kept element from view
|
||||
# load bboxes of next highest vals
|
||||
torch.index_select(x1, 0, idx, out=xx1)
|
||||
torch.index_select(y1, 0, idx, out=yy1)
|
||||
torch.index_select(x2, 0, idx, out=xx2)
|
||||
torch.index_select(y2, 0, idx, out=yy2)
|
||||
# store element-wise max with next highest score
|
||||
xx1 = torch.clamp(xx1, min=x1[i])
|
||||
yy1 = torch.clamp(yy1, min=y1[i])
|
||||
xx2 = torch.clamp(xx2, max=x2[i])
|
||||
yy2 = torch.clamp(yy2, max=y2[i])
|
||||
w.resize_as_(xx2)
|
||||
h.resize_as_(yy2)
|
||||
w = xx2 - xx1
|
||||
h = yy2 - yy1
|
||||
# check sizes of xx1 and xx2.. after each iteration
|
||||
w = torch.clamp(w, min=0.0)
|
||||
h = torch.clamp(h, min=0.0)
|
||||
inter = w * h
|
||||
# IoU = i / (area(a) + area(b) - i)
|
||||
rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
|
||||
union = (rem_areas - inter) + area[i]
|
||||
IoU = inter / union # store result in iou
|
||||
# keep only elements with an IoU <= overlap
|
||||
idx = idx[IoU.le(overlap)]
|
||||
return keep, count
|
|
@ -0,0 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .prior_box import PriorBox
|
||||
from .detection import detect_function
|
||||
|
||||
__all__=['detect_function','PriorBox']
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
|
||||
from ..bbox_utils import decode, nms
|
||||
|
||||
|
||||
|
||||
def detect_function(cfg, loc_data, conf_data, prior_data):
|
||||
"""
|
||||
Args:
|
||||
loc_data: (tensor) Loc preds from loc layers
|
||||
Shape: [batch,num_priors*4]
|
||||
conf_data: (tensor) Conf preds from conf layers
|
||||
Shape: [batch*num_priors,num_classes]
|
||||
prior_data: (tensor) Prior boxes and variances from priorbox layers
|
||||
Shape: [1,num_priors,4]
|
||||
"""
|
||||
with torch.no_grad():
|
||||
num = loc_data.size(0)
|
||||
num_priors = prior_data.size(0)
|
||||
|
||||
conf_preds = conf_data.view(
|
||||
num, num_priors, cfg.NUM_CLASSES).transpose(2, 1)
|
||||
batch_priors = prior_data.view(-1, num_priors,
|
||||
4).expand(num, num_priors, 4)
|
||||
batch_priors = batch_priors.contiguous().view(-1, 4)
|
||||
|
||||
decoded_boxes = decode(loc_data.view(-1, 4),
|
||||
batch_priors, cfg.VARIANCE)
|
||||
decoded_boxes = decoded_boxes.view(num, num_priors, 4)
|
||||
|
||||
output = torch.zeros(num, cfg.NUM_CLASSES, cfg.TOP_K, 5)
|
||||
|
||||
for i in range(num):
|
||||
boxes = decoded_boxes[i].clone()
|
||||
conf_scores = conf_preds[i].clone()
|
||||
|
||||
for cl in range(1, cfg.NUM_CLASSES):
|
||||
c_mask = conf_scores[cl].gt(cfg.CONF_THRESH)
|
||||
scores = conf_scores[cl][c_mask]
|
||||
|
||||
if scores.dim() == 0:
|
||||
continue
|
||||
l_mask = c_mask.unsqueeze(1).expand_as(boxes)
|
||||
boxes_ = boxes[l_mask].view(-1, 4)
|
||||
ids, count = nms(
|
||||
boxes_, scores, cfg.NMS_THRESH, cfg.NMS_TOP_K)
|
||||
count = count if count < cfg.TOP_K else cfg.TOP_K
|
||||
|
||||
output[i, cl, :count] = torch.cat((scores[ids[:count]].unsqueeze(1),
|
||||
boxes_[ids[:count]]), 1)
|
||||
|
||||
return output
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
from itertools import product as product
|
||||
import math
|
||||
|
||||
|
||||
class PriorBox(object):
|
||||
"""Compute priorbox coordinates in center-offset form for each source
|
||||
feature map.
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, feature_maps,cfg):
|
||||
super(PriorBox, self).__init__()
|
||||
self.imh = input_size[0]
|
||||
self.imw = input_size[1]
|
||||
|
||||
# number of priors for feature map location (either 4 or 6)
|
||||
self.variance = cfg.VARIANCE or [0.1]
|
||||
#self.feature_maps = cfg.FEATURE_MAPS
|
||||
self.min_sizes = cfg.ANCHOR_SIZES
|
||||
self.steps = cfg.STEPS
|
||||
self.clip = cfg.CLIP
|
||||
for v in self.variance:
|
||||
if v <= 0:
|
||||
raise ValueError('Variances must be greater than 0')
|
||||
self.feature_maps = feature_maps
|
||||
|
||||
|
||||
def forward(self):
|
||||
mean = []
|
||||
for k in range(len(self.feature_maps)):
|
||||
feath = self.feature_maps[k][0]
|
||||
featw = self.feature_maps[k][1]
|
||||
for i, j in product(range(feath), range(featw)):
|
||||
f_kw = self.imw / self.steps[k]
|
||||
f_kh = self.imh / self.steps[k]
|
||||
|
||||
cx = (j + 0.5) / f_kw
|
||||
cy = (i + 0.5) / f_kh
|
||||
|
||||
s_kw = self.min_sizes[k] / self.imw
|
||||
s_kh = self.min_sizes[k] / self.imh
|
||||
|
||||
mean += [cx, cy, s_kw, s_kh]
|
||||
|
||||
output = torch.Tensor(mean).view(-1, 4)
|
||||
if self.clip:
|
||||
output.clamp_(max=1, min=0)
|
||||
return output
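
# Usage sketch (illustrative values): this mirrors how the S3FD model builds its priors at the
# end of forward(), where feature_maps is a list of [height, width] pairs, one per detection
# head, and cfg is the EasyDict from data.config (or data.config_qvga).
#
#   feature_maps = [[160, 160], [80, 80], [40, 40], [20, 20], [10, 10], [5, 5]]
#   priorbox = PriorBox((640, 640), feature_maps, cfg)
#   priors = priorbox.forward()   # tensor of shape [num_priors, 4] in (cx, cy, s_w, s_h) form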
|
|
@ -0,0 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .l2norm import L2Norm
|
||||
from .multibox_loss import MultiBoxLoss
|
||||
|
||||
__all__ = ['L2Norm', 'MultiBoxLoss']
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
|
||||
|
||||
class L2Norm(nn.Module):
|
||||
def __init__(self,n_channels, scale):
|
||||
super(L2Norm,self).__init__()
|
||||
self.n_channels = n_channels
|
||||
self.gamma = scale or None
|
||||
self.eps = 1e-10
|
||||
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
init.constant_(self.weight,self.gamma)
|
||||
|
||||
def forward(self, x):
|
||||
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps
|
||||
#x /= norm
|
||||
x = torch.div(x,norm)
|
||||
out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x
|
||||
return out
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
from ..bbox_utils import match, log_sum_exp, match_ssd
|
||||
|
||||
|
||||
class MultiBoxLoss(nn.Module):
|
||||
"""SSD Weighted Loss Function
|
||||
Compute Targets:
|
||||
1) Produce Confidence Target Indices by matching ground truth boxes
|
||||
with (default) 'priorboxes' that have jaccard index > threshold parameter
|
||||
(default threshold: 0.5).
|
||||
2) Produce localization target by 'encoding' variance into offsets of ground
|
||||
truth boxes and their matched 'priorboxes'.
|
||||
3) Hard negative mining to filter the excessive number of negative examples
|
||||
that comes with using a large number of default bounding boxes.
|
||||
(default negative:positive ratio 3:1)
|
||||
Objective Loss:
|
||||
L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
|
||||
Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
|
||||
weighted by α which is set to 1 by cross val.
|
||||
Args:
|
||||
c: class confidences,
|
||||
l: predicted boxes,
|
||||
g: ground truth boxes
|
||||
N: number of matched default boxes
|
||||
See: https://arxiv.org/pdf/1512.02325.pdf for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, dataset, use_gpu=True):
|
||||
super(MultiBoxLoss, self).__init__()
|
||||
self.use_gpu = use_gpu
|
||||
self.num_classes = cfg.NUM_CLASSES
|
||||
self.negpos_ratio = cfg.NEG_POS_RATIOS
|
||||
self.variance = cfg.VARIANCE
|
||||
self.dataset = dataset
|
||||
|
||||
self.threshold = cfg.FACE.OVERLAP_THRESH
|
||||
self.match = match
|
||||
|
||||
def forward(self, predictions, targets):
|
||||
"""Multibox Loss
|
||||
Args:
|
||||
predictions (tuple): A tuple containing loc preds, conf preds,
|
||||
and prior boxes from SSD net.
|
||||
conf shape: torch.size(batch_size,num_priors,num_classes)
|
||||
loc shape: torch.size(batch_size,num_priors,4)
|
||||
priors shape: torch.size(num_priors,4)
|
||||
|
||||
targets (tensor): Ground truth boxes and labels for a batch,
|
||||
shape: [batch_size,num_objs,5] (last idx is the label).
|
||||
"""
|
||||
loc_data, conf_data, priors = predictions
|
||||
num = loc_data.size(0)
|
||||
priors = priors[:loc_data.size(1), :]
|
||||
num_priors = (priors.size(0))
|
||||
num_classes = self.num_classes
|
||||
|
||||
# match priors (default boxes) and ground truth boxes
|
||||
loc_t = torch.Tensor(num, num_priors, 4)
|
||||
conf_t = torch.LongTensor(num, num_priors)
|
||||
for idx in range(num):
|
||||
truths = targets[idx][:, :-1].data
|
||||
labels = targets[idx][:, -1].data
|
||||
defaults = priors.data
|
||||
self.match(self.threshold, truths, defaults, self.variance, labels,
|
||||
loc_t, conf_t, idx)
|
||||
if self.use_gpu:
|
||||
loc_t = loc_t.cuda()
|
||||
conf_t = conf_t.cuda()
|
||||
# wrap targets
|
||||
loc_t = Variable(loc_t, requires_grad=False)
|
||||
conf_t = Variable(conf_t, requires_grad=False)
|
||||
|
||||
pos = conf_t > 0
|
||||
num_pos = pos.sum(dim=1, keepdim=True)
|
||||
# Localization Loss (Smooth L1)
|
||||
# Shape: [batch,num_priors,4]
|
||||
pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
|
||||
loc_p = loc_data[pos_idx].view(-1, 4)
|
||||
loc_t = loc_t[pos_idx].view(-1, 4)
|
||||
loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)
|
||||
# print(loc_p)
|
||||
# Compute max conf across batch for hard negative mining
|
||||
batch_conf = conf_data.view(-1, self.num_classes)
|
||||
loss_c = log_sum_exp(batch_conf) - \
|
||||
batch_conf.gather(1, conf_t.view(-1, 1))
|
||||
|
||||
# Hard Negative Mining
|
||||
loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now
|
||||
loss_c = loss_c.view(num, -1)
|
||||
_, loss_idx = loss_c.sort(1, descending=True)
|
||||
_, idx_rank = loss_idx.sort(1)
|
||||
num_pos = pos.long().sum(1, keepdim=True)
|
||||
num_neg = torch.clamp(self.negpos_ratio *
|
||||
num_pos, max=pos.size(1) - 1)
|
||||
neg = idx_rank < num_neg.expand_as(idx_rank)
|
||||
|
||||
# Confidence Loss Including Positive and Negative Examples
|
||||
pos_idx = pos.unsqueeze(2).expand_as(conf_data)
|
||||
neg_idx = neg.unsqueeze(2).expand_as(conf_data)
|
||||
conf_p = conf_data[(pos_idx + neg_idx).gt(0)
|
||||
].view(-1, self.num_classes)
|
||||
targets_weighted = conf_t[(pos + neg).gt(0)]
|
||||
loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)
|
||||
|
||||
# Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
|
||||
N = num_pos.data.sum() if num_pos.data.sum() > 0 else num
|
||||
loss_l /= N
|
||||
loss_c /= N
|
||||
return loss_l, loss_c
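
# Training-step sketch (illustrative; the actual loop lives in train.py). `net` is assumed to
# be an S3FD-style model in 'train' phase returning (loc preds, conf preds, priors), and
# `targets` is the list of per-image box tensors produced by detection_collate.
#
#   criterion = MultiBoxLoss(cfg, dataset, use_gpu=torch.cuda.is_available())
#   loss_l, loss_c = criterion(net(images), targets)
#   loss = loss_l + loss_c
#   loss.backward()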
|
|
@ -0,0 +1,382 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from layers import *
|
||||
from data.config import cfg
|
||||
import numpy as np
|
||||
|
||||
from edgeml_pytorch.graph.rnnpool import *
|
||||
|
||||
class S3FD(nn.Module):
|
||||
"""Single Shot Multibox Architecture
|
||||
The network is composed of a base VGG network followed by the
|
||||
added multibox conv layers. Each multibox layer branches into
|
||||
1) conv2d for class conf scores
|
||||
2) conv2d for localization predictions
|
||||
3) associated priorbox layer to produce default bounding
|
||||
boxes specific to the layer's feature map size.
|
||||
See: https://arxiv.org/pdf/1512.02325.pdf for more details.
|
||||
|
||||
Args:
|
||||
phase: (string) Can be "test" or "train"
|
||||
size: input image size
|
||||
base: base network layers (MobileNet blocks in this model)
|
||||
extras: extra layers that feed to multibox loc and conf layers
|
||||
head: "multibox head" consists of loc and conf conv layers
|
||||
"""
|
||||
|
||||
def __init__(self, phase, base, head, num_classes):
|
||||
super(S3FD, self).__init__()
|
||||
self.phase = phase
|
||||
self.num_classes = num_classes
|
||||
'''
|
||||
self.priorbox = PriorBox(size,cfg)
|
||||
self.priors = Variable(self.priorbox.forward(), volatile=True)
|
||||
'''
|
||||
# SSD network
|
||||
|
||||
self.unfold = nn.Unfold(kernel_size=(8,8),stride=(4,4))
|
||||
|
||||
self.rnn_model = RNNPool(8, 8, 16, 16, 3)
|
||||
|
||||
self.mob = nn.ModuleList(base)
|
||||
# Layer learns to scale the l2 normalized features from conv4_3
|
||||
self.L2Norm3_3 = L2Norm(24, 10)
|
||||
self.L2Norm4_3 = L2Norm(32, 8)
|
||||
self.L2Norm5_3 = L2Norm(64, 5)
|
||||
|
||||
|
||||
self.loc = nn.ModuleList(head[0])
|
||||
self.conf = nn.ModuleList(head[1])
|
||||
|
||||
if self.phase == 'test':
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
# self.detect = Detect(cfg)
|
||||
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies network layers and ops on input image(s) x.
|
||||
|
||||
Args:
|
||||
x: input image or batch of images. Shape: [batch,3,300,300].
|
||||
|
||||
Return:
|
||||
Depending on phase:
|
||||
test:
|
||||
Variable(tensor) of output class label predictions,
|
||||
confidence score, and corresponding location predictions for
|
||||
each object detected. Shape: [batch,topk,7]
|
||||
|
||||
train:
|
||||
list of concat outputs from:
|
||||
1: confidence layers, Shape: [batch*num_priors,num_classes]
|
||||
2: localization layers, Shape: [batch,num_priors*4]
|
||||
3: priorbox layers, Shape: [2,num_priors*4]
|
||||
"""
|
||||
size = x.size()[2:]
|
||||
batch_size = x.shape[0]
|
||||
sources = list()
|
||||
loc = list()
|
||||
conf = list()
|
||||
|
||||
patches = self.unfold(x)
|
||||
patches = torch.cat(torch.unbind(patches,dim=2),dim=0)
|
||||
patches = torch.reshape(patches,(-1,3,8,8))
|
||||
|
||||
output_x = int((x.shape[2]-8)/4 + 1)
|
||||
output_y = int((x.shape[3]-8)/4 + 1)
|
||||
|
||||
rnnX = self.rnn_model(patches, int(batch_size)*output_x*output_y)
|
||||
|
||||
x = torch.stack(torch.split(rnnX, split_size_or_sections=int(batch_size), dim=0),dim=2)
|
||||
|
||||
x = F.fold(x, kernel_size=(1,1), output_size=(output_x,output_y))
|
||||
|
||||
x = F.pad(x, (0,1,0,1), mode='replicate')
|
||||
|
||||
|
||||
|
||||
for k in range(2):
|
||||
x = self.mob[k](x)
|
||||
|
||||
s = self.L2Norm3_3(x)
|
||||
sources.append(s)
|
||||
|
||||
for k in range(2, 5):
|
||||
x = self.mob[k](x)
|
||||
|
||||
s = self.L2Norm4_3(x)
|
||||
sources.append(s)
|
||||
|
||||
for k in range(5, 9):
|
||||
x = self.mob[k](x)
|
||||
|
||||
s = self.L2Norm5_3(x)
|
||||
sources.append(s)
|
||||
|
||||
for k in range(9, 12):
|
||||
x = self.mob[k](x)
|
||||
sources.append(x)
|
||||
|
||||
for k in range(12, 14):
|
||||
x = self.mob[k](x)
|
||||
sources.append(x)
|
||||
|
||||
for k in range(14, 15):
|
||||
x = self.mob[k](x)
|
||||
sources.append(x)
|
||||
|
||||
|
||||
|
||||
# apply multibox head to source layers
|
||||
|
||||
loc_x = self.loc[0](sources[0])
|
||||
conf_x = self.conf[0](sources[0])
|
||||
|
||||
max_conf, _ = torch.max(conf_x[:, 0:3, :, :], dim=1, keepdim=True)
|
||||
conf_x = torch.cat((max_conf, conf_x[:, 3:, :, :]), dim=1)
|
||||
|
||||
loc.append(loc_x.permute(0, 2, 3, 1).contiguous())
|
||||
conf.append(conf_x.permute(0, 2, 3, 1).contiguous())
|
||||
|
||||
for i in range(1, len(sources)):
|
||||
x = sources[i]
|
||||
conf.append(self.conf[i](x).permute(0, 2, 3, 1).contiguous())
|
||||
loc.append(self.loc[i](x).permute(0, 2, 3, 1).contiguous())
|
||||
|
||||
|
||||
features_maps = []
|
||||
for i in range(len(loc)):
|
||||
feat = []
|
||||
feat += [loc[i].size(1), loc[i].size(2)]
|
||||
features_maps += [feat]
|
||||
|
||||
self.priorbox = PriorBox(size, features_maps, cfg)
|
||||
self.priors = self.priorbox.forward()
|
||||
|
||||
loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
|
||||
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
|
||||
|
||||
|
||||
if self.phase == 'test':
|
||||
output = detect_function(
|
||||
loc.view(loc.size(0), -1, 4), # loc preds
|
||||
self.softmax(conf.view(conf.size(0), -1,
|
||||
self.num_classes)), # conf preds
|
||||
self.priors.type(type(x.data)) # default boxes
|
||||
)
|
||||
|
||||
else:
|
||||
output = (
|
||||
loc.view(loc.size(0), -1, 4),
|
||||
conf.view(conf.size(0), -1, self.num_classes),
|
||||
self.priors
|
||||
)
|
||||
return output
|
||||
|
||||
def load_weights(self, base_file):
|
||||
other, ext = os.path.splitext(base_file)
|
||||
if ext == '.pkl' or ext == '.pth':
|
||||
print('Loading weights into state dict...')
|
||||
mdata = torch.load(base_file,
|
||||
map_location=lambda storage, loc: storage)
|
||||
weights = mdata['weight']
|
||||
epoch = mdata['epoch']
|
||||
self.load_state_dict(weights)
|
||||
print('Finished!')
|
||||
else:
|
||||
print('Sorry, only .pth and .pkl files are supported.')
|
||||
return epoch
|
||||
|
||||
def xavier(self, param):
|
||||
init.xavier_uniform_(param)
|
||||
|
||||
def weights_init(self, m):
|
||||
if isinstance(m, nn.Conv2d):
|
||||
self.xavier(m.weight.data)
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
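# Editor's note (illustrative, not part of the original file):
#   _make_divisible(30, 8) -> 32   (rounded up to the nearest multiple of 8)
#   _make_divisible(16, 8) -> 16   (already a multiple of 8)
#   _make_divisible(7, 8)  -> 8    (never returns less than min_value = divisor)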
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
||||
padding = (kernel_size - 1) // 2
|
||||
super(ConvBNReLU, self).__init__(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
|
||||
"""
|
||||
MobileNet V2 main class
|
||||
Args:
|
||||
num_classes (int): Number of classes
|
||||
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||
inverted_residual_setting: Network structure
|
||||
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||
Set to 1 to turn off rounding
|
||||
"""
|
||||
super(MobileNetV2, self).__init__()
|
||||
block = InvertedResidual
|
||||
input_channel = 64
|
||||
|
||||
if inverted_residual_setting is None:
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
# [1, 16, 1, 1],
|
||||
[1, 24, 1, 1],
|
||||
[6, 24, 1, 1],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 2],
|
||||
[6, 160, 2, 2],
|
||||
[6, 320, 1, 2],
|
||||
]
|
||||
|
||||
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
||||
raise ValueError("inverted_residual_setting should be non-empty "
|
||||
"or a 4-element list, got {}".format(inverted_residual_setting))
|
||||
|
||||
# building first layer
|
||||
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
||||
last_channel = 1280  # assumed default (as in the torchvision MobileNetV2 reference); undefined in the original
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
||||
self.layers = []
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
self.layers.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
||||
input_channel = output_channel
|
||||
|
||||
# weight initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
|
||||
def multibox(mobilenet, num_classes):
|
||||
loc_layers = []
|
||||
conf_layers = []
|
||||
|
||||
loc_layers += [nn.Conv2d(24, 4,
|
||||
kernel_size=5, padding=2)]
|
||||
conf_layers += [nn.Conv2d(24,
|
||||
3 + (num_classes-1), kernel_size=5, padding=2)]
|
||||
|
||||
loc_layers += [nn.Conv2d(32,
|
||||
4, kernel_size=5, padding=2)]
|
||||
conf_layers += [nn.Conv2d(32,
|
||||
num_classes, kernel_size=5, padding=2)]
|
||||
|
||||
loc_layers += [nn.Conv2d(64,
|
||||
4, kernel_size=5, padding=2)]
|
||||
conf_layers += [nn.Conv2d(64,
|
||||
num_classes, kernel_size=5, padding=2)]
|
||||
|
||||
loc_layers += [nn.Conv2d(96,
|
||||
4, kernel_size=1, padding=0)]
|
||||
conf_layers += [nn.Conv2d(96,
|
||||
num_classes, kernel_size=1, padding=0)]
|
||||
|
||||
loc_layers += [nn.Conv2d(160,
|
||||
4, kernel_size=1, padding=0)]
|
||||
conf_layers += [nn.Conv2d(160,
|
||||
num_classes, kernel_size=1, padding=0)]
|
||||
|
||||
loc_layers += [nn.Conv2d(320,
|
||||
4, kernel_size=3, padding=1)]
|
||||
conf_layers += [nn.Conv2d(320,
|
||||
num_classes, kernel_size=3, padding=1)]
|
||||
|
||||
|
||||
|
||||
return mobilenet, (loc_layers, conf_layers)
|
||||
|
||||
|
||||
def build_s3fd(phase, num_classes=2):
|
||||
base_, head_ = multibox(
|
||||
MobileNetV2().layers, num_classes)
|
||||
|
||||
return S3FD(phase, base_, head_, num_classes)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = build_s3fd('train', num_classes=2)
|
||||
inputs = Variable(torch.randn(4, 3, 640, 640))
|
||||
output = net(inputs)
|
||||
|
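# --- Editor's sketch (illustrative, not part of the original file) ---
# Shape bookkeeping for the unfold -> RNNPool -> fold front end in S3FD.forward()
# above, for a batch of B RGB images of size H x W (e.g. 640 x 640 as in __main__),
# with 8x8 patches at stride 4 and P = ((H-8)//4 + 1) * ((W-8)//4 + 1) patch locations:
#   x        : [B, 3, H, W]
#   unfold   : [B, 3*8*8, P]
#   patches  : [P*B, 3, 8, 8]     after unbind / cat / reshape
#   rnnX     : [P*B, D]           D = RNNPool output width for this configuration
#   fold     : [B, D, 159, 159]   for H = W = 640, since (640-8)//4 + 1 = 159
#   pad      : [B, D, 160, 160]   replicate-padded before the MobileNet blocks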
|
@ -0,0 +1,348 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from layers import *
|
||||
from data.config_qvga import cfg
|
||||
import numpy as np
|
||||
|
||||
from edgeml_pytorch.graph.rnnpool import *
|
||||
|
||||
class S3FD(nn.Module):
|
||||
"""Single Shot Multibox Architecture
|
||||
The network is composed of a base VGG network followed by the
|
||||
added multibox conv layers. Each multibox layer branches into
|
||||
1) conv2d for class conf scores
|
||||
2) conv2d for localization predictions
|
||||
3) associated priorbox layer to produce default bounding
|
||||
boxes specific to the layer's feature map size.
|
||||
See: https://arxiv.org/pdf/1512.02325.pdf for more details.
|
||||
Args:
|
||||
phase: (string) Can be "test" or "train"
|
||||
size: input image size
|
||||
base: VGG16 layers for input, size of either 300 or 500
|
||||
extras: extra layers that feed to multibox loc and conf layers
|
||||
head: "multibox head" consists of loc and conf conv layers
|
||||
"""
|
||||
|
||||
def __init__(self, phase, base, head, num_classes):
|
||||
super(S3FD, self).__init__()
|
||||
self.phase = phase
|
||||
self.num_classes = num_classes
|
||||
'''
|
||||
self.priorbox = PriorBox(size,cfg)
|
||||
self.priors = Variable(self.priorbox.forward(), volatile=True)
|
||||
'''
|
||||
# SSD network
|
||||
self.conv = ConvBNReLU(1, 4, stride=2)
|
||||
|
||||
self.unfold = nn.Unfold(kernel_size=(8,8),stride=(4,4))
|
||||
|
||||
self.rnn_model = RNNPool(8, 8, 16, 16, 4)#num_init_features)
|
||||
|
||||
self.mob = nn.ModuleList(base)
|
||||
# Layer learns to scale the l2 normalized features from conv4_3
|
||||
self.L2Norm3_3 = L2Norm(32, 10)
|
||||
self.L2Norm4_3 = L2Norm(32, 8)
|
||||
self.L2Norm5_3 = L2Norm(96, 5)
|
||||
|
||||
|
||||
self.loc = nn.ModuleList(head[0])
|
||||
self.conf = nn.ModuleList(head[1])
|
||||
|
||||
|
||||
if self.phase == 'test':
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies network layers and ops on input image(s) x.
|
||||
Args:
|
||||
x: input image or batch of images. Shape: [batch,1,320,320].
|
||||
Return:
|
||||
Depending on phase:
|
||||
test:
|
||||
Variable(tensor) of output class label predictions,
|
||||
confidence score, and corresponding location predictions for
|
||||
each object detected. Shape: [batch,topk,7]
|
||||
train:
|
||||
list of concat outputs from:
|
||||
1: confidence layers, Shape: [batch*num_priors,num_classes]
|
||||
2: localization layers, Shape: [batch,num_priors*4]
|
||||
3: priorbox layers, Shape: [2,num_priors*4]
|
||||
"""
|
||||
size = x.size()[2:]
|
||||
batch_size = x.shape[0]
|
||||
sources = list()
|
||||
loc = list()
|
||||
conf = list()
|
||||
|
||||
x = self.conv(x)
|
||||
|
||||
patches = self.unfold(x)
|
||||
patches = torch.cat(torch.unbind(patches,dim=2),dim=0)
|
||||
patches = torch.reshape(patches,(-1,4,8,8))
|
||||
|
||||
output_x = int((x.shape[2]-8)/4 + 1)
|
||||
output_y = int((x.shape[3]-8)/4 + 1)
|
||||
|
||||
rnnX = self.rnn_model(patches, int(batch_size)*output_x*output_y)
|
||||
|
||||
x = torch.stack(torch.split(rnnX, split_size_or_sections=int(batch_size), dim=0),dim=2)
|
||||
|
||||
x = F.fold(x, kernel_size=(1,1), output_size=(output_x,output_y))
|
||||
|
||||
x = F.pad(x, (0,1,0,1), mode='replicate')
|
||||
|
||||
for k in range(4):
|
||||
x = self.mob[k](x)
|
||||
|
||||
s = self.L2Norm3_3(x)
|
||||
sources.append(s)
|
||||
|
||||
for k in range(4, 8):
|
||||
x = self.mob[k](x)
|
||||
|
||||
s = self.L2Norm4_3(x)
|
||||
sources.append(s)
|
||||
|
||||
for k in range(8, 11):
|
||||
x = self.mob[k](x)
|
||||
|
||||
s = self.L2Norm5_3(x)
|
||||
sources.append(s)
|
||||
|
||||
for k in range(11, 14):
|
||||
x = self.mob[k](x)
|
||||
sources.append(x)
|
||||
|
||||
|
||||
# apply multibox head to source layers
|
||||
|
||||
loc_x = self.loc[0](sources[0])
|
||||
conf_x = self.conf[0](sources[0])
|
||||
|
||||
max_conf, _ = torch.max(conf_x[:, 0:3, :, :], dim=1, keepdim=True)
|
||||
conf_x = torch.cat((max_conf, conf_x[:, 3:, :, :]), dim=1)
|
||||
|
||||
loc.append(loc_x.permute(0, 2, 3, 1).contiguous())
|
||||
conf.append(conf_x.permute(0, 2, 3, 1).contiguous())
|
||||
|
||||
for i in range(1, len(sources)):
|
||||
x = sources[i]
|
||||
conf.append(self.conf[i](x).permute(0, 2, 3, 1).contiguous())
|
||||
loc.append(self.loc[i](x).permute(0, 2, 3, 1).contiguous())
|
||||
|
||||
|
||||
features_maps = []
|
||||
for i in range(len(loc)):
|
||||
feat = []
|
||||
feat += [loc[i].size(1), loc[i].size(2)]
|
||||
features_maps += [feat]
|
||||
|
||||
self.priorbox = PriorBox(size, features_maps, cfg)
|
||||
|
||||
self.priors = self.priorbox.forward()
|
||||
|
||||
loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
|
||||
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
|
||||
|
||||
|
||||
if self.phase == 'test':
|
||||
output = detect_function(cfg,
|
||||
loc.view(loc.size(0), -1, 4), # loc preds
|
||||
self.softmax(conf.view(conf.size(0), -1,
|
||||
self.num_classes)), # conf preds
|
||||
self.priors.type(type(x.data)) # default boxes
|
||||
)
|
||||
|
||||
else:
|
||||
output = (
|
||||
loc.view(loc.size(0), -1, 4),
|
||||
conf.view(conf.size(0), -1, self.num_classes),
|
||||
self.priors
|
||||
)
|
||||
return output
|
||||
|
||||
def load_weights(self, base_file):
|
||||
other, ext = os.path.splitext(base_file)
|
||||
if ext == '.pkl' or ext == '.pth':
|
||||
print('Loading weights into state dict...')
|
||||
mdata = torch.load(base_file,
|
||||
map_location=lambda storage, loc: storage)
|
||||
weights = mdata['weight']
|
||||
epoch = mdata['epoch']
|
||||
self.load_state_dict(weights)
|
||||
print('Finished!')
|
||||
else:
|
||||
print('Sorry, only .pth and .pkl files are supported.')
|
||||
return epoch
|
||||
|
||||
def xavier(self, param):
|
||||
init.xavier_uniform_(param)
|
||||
|
||||
def weights_init(self, m):
|
||||
if isinstance(m, nn.Conv2d):
|
||||
self.xavier(m.weight.data)
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
||||
padding = (kernel_size - 1) // 2
|
||||
super(ConvBNReLU, self).__init__(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
|
||||
"""
|
||||
MobileNet V2 main class
|
||||
Args:
|
||||
num_classes (int): Number of classes
|
||||
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||
inverted_residual_setting: Network structure
|
||||
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||
Set to 1 to turn off rounding
|
||||
"""
|
||||
super(MobileNetV2, self).__init__()
|
||||
block = InvertedResidual
|
||||
input_channel = 64
|
||||
|
||||
if inverted_residual_setting is None:
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[2, 32, 4, 1],
|
||||
[2, 32, 4, 1],
|
||||
[2, 96, 3, 2],
|
||||
[2, 128, 3, 1],
|
||||
]
|
||||
|
||||
# only check the first element, assuming user knows t,c,n,s are required
|
||||
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
||||
raise ValueError("inverted_residual_setting should be non-empty "
|
||||
"or a 4-element list, got {}".format(inverted_residual_setting))
|
||||
|
||||
# building first layer
|
||||
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
||||
last_channel = 1280  # assumed default (as in the torchvision MobileNetV2 reference); undefined in the original
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
||||
self.layers = []
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
self.layers.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
||||
input_channel = output_channel
|
||||
|
||||
# weight initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
|
||||
|
||||
def multibox(mobilenet, num_classes):
|
||||
loc_layers = []
|
||||
conf_layers = []
|
||||
|
||||
loc_layers += [nn.Conv2d(32, 4, kernel_size=3, padding=1)]
|
||||
conf_layers += [nn.Conv2d(32, 3 + (num_classes-1), kernel_size=3, padding=1)]
|
||||
|
||||
loc_layers += [nn.Conv2d(32, 4, kernel_size=3, padding=1)]
|
||||
conf_layers += [nn.Conv2d(32, num_classes, kernel_size=3, padding=1)]
|
||||
|
||||
loc_layers += [nn.Conv2d(96, 4, kernel_size=3, padding=1)]
|
||||
conf_layers += [nn.Conv2d(96, num_classes, kernel_size=3, padding=1)]
|
||||
|
||||
loc_layers += [nn.Conv2d(128, 4, kernel_size=3, padding=1)]
|
||||
conf_layers += [nn.Conv2d(128, num_classes, kernel_size=3, padding=1)]
|
||||
|
||||
|
||||
return mobilenet, (loc_layers, conf_layers)
|
||||
|
||||
|
||||
def build_s3fd(phase, num_classes=2):
|
||||
base_, head_ = multibox(
|
||||
MobileNetV2().layers, num_classes)
|
||||
|
||||
return S3FD(phase, base_, head_, num_classes)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = build_s3fd('train', num_classes=2)
|
||||
inputs = Variable(torch.randn(4, 1, 320, 320))
|
||||
output = net(inputs)
|
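# --- Editor's note (illustrative, not part of the original file) ---
# The in_channels of the loc/conf convs built in multibox() above must match the
# channels of the feature maps collected as `sources` in S3FD.forward():
#   sources[0]  after mob[0:4]   -> 32 channels  (L2Norm3_3) -> head convs with in_channels 32
#   sources[1]  after mob[4:8]   -> 32 channels  (L2Norm4_3) -> in_channels 32
#   sources[2]  after mob[8:11]  -> 96 channels  (L2Norm5_3) -> in_channels 96
#   sources[3]  after mob[11:14] -> 128 channels              -> in_channels 128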
|
@ -0,0 +1,375 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from layers import *
|
||||
from data.config import cfg
|
||||
import numpy as np
|
||||
|
||||
from edgeml_pytorch.graph.rnnpool import *
|
||||
|
||||
class S3FD(nn.Module):
|
||||
"""Single Shot Multibox Architecture
|
||||
The network is composed of a base VGG network followed by the
|
||||
added multibox conv layers. Each multibox layer branches into
|
||||
1) conv2d for class conf scores
|
||||
2) conv2d for localization predictions
|
||||
3) associated priorbox layer to produce default bounding
|
||||
boxes specific to the layer's feature map size.
|
||||
See: https://arxiv.org/pdf/1512.02325.pdf for more details.
|
||||
Args:
|
||||
phase: (string) Can be "test" or "train"
|
||||
size: input image size
|
||||
base: VGG16 layers for input, size of either 300 or 500
|
||||
extras: extra layers that feed to multibox loc and conf layers
|
||||
head: "multibox head" consists of loc and conf conv layers
|
||||
"""
|
||||
|
||||
def __init__(self, phase, base, head, num_classes):
|
||||
super(S3FD, self).__init__()
|
||||
self.phase = phase
|
||||
self.num_classes = num_classes
|
||||
'''
|
||||
self.priorbox = PriorBox(size,cfg)
|
||||
self.priors = Variable(self.priorbox.forward(), volatile=True)
|
||||
'''
|
||||
# SSD network
|
||||
|
||||
self.conv_top = nn.Sequential(ConvBNReLU(3, 4, kernel_size=3, stride=2), ConvBNReLU(4, 4, kernel_size=3))
|
||||
|
||||
self.unfold = nn.Unfold(kernel_size=(8,8),stride=(4,4))
|
||||
|
||||
self.rnn_model = RNNPool(8, 8, 8, 8, 4)
|
||||
|
||||
self.mob = nn.ModuleList(base)
|
||||
# Layer learns to scale the l2 normalized features from conv4_3
|
||||
self.L2Norm3_3 = L2Norm(4, 10)
|
||||
self.L2Norm4_3 = L2Norm(16, 8)
|
||||
self.L2Norm5_3 = L2Norm(24, 5)
|
||||
|
||||
|
||||
self.loc = nn.ModuleList(head[0])
|
||||
self.conf = nn.ModuleList(head[1])
|
||||
|
||||
if self.phase == 'test':
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies network layers and ops on input image(s) x.
|
||||
Args:
|
||||
x: input image or batch of images. Shape: [batch,3,300,300].
|
||||
Return:
|
||||
Depending on phase:
|
||||
test:
|
||||
Variable(tensor) of output class label predictions,
|
||||
confidence score, and corresponding location predictions for
|
||||
each object detected. Shape: [batch,topk,7]
|
||||
train:
|
||||
list of concat outputs from:
|
||||
1: confidence layers, Shape: [batch*num_priors,num_classes]
|
||||
2: localization layers, Shape: [batch,num_priors*4]
|
||||
3: priorbox layers, Shape: [2,num_priors*4]
|
||||
"""
|
||||
size = x.size()[2:]
|
||||
batch_size = x.shape[0]
|
||||
sources = list()
|
||||
loc = list()
|
||||
conf = list()
|
||||
|
||||
x = self.conv_top(x)
|
||||
|
||||
s = self.L2Norm3_3(x)
|
||||
sources.append(s)
|
||||
|
||||
patches = self.unfold(x)
|
||||
patches = torch.cat(torch.unbind(patches,dim=2),dim=0)
|
||||
patches = torch.reshape(patches,(-1,4,8,8))
|
||||
|
||||
output_x = int((x.shape[2]-8)/4 + 1)
|
||||
output_y = int((x.shape[3]-8)/4 + 1)
|
||||
|
||||
rnnX = self.rnn_model(patches, int(batch_size)*output_x*output_y)
|
||||
|
||||
x = torch.stack(torch.split(rnnX, split_size_or_sections=int(batch_size), dim=0),dim=2)
|
||||
|
||||
x = F.fold(x, kernel_size=(1,1), output_size=(output_x,output_y))
|
||||
|
||||
x = F.pad(x, (0,1,0,1), mode='replicate')
|
||||
|
||||
for k in range(4):
|
||||
x = self.mob[k](x)
|
||||
|
||||
s = self.L2Norm4_3(x)
|
||||
sources.append(s)
|
||||
|
||||
for k in range(4, 8):
|
||||
x = self.mob[k](x)
|
||||
|
||||
s = self.L2Norm5_3(x)
|
||||
sources.append(s)
|
||||
|
||||
for k in range(8, 10):
|
||||
x = self.mob[k](x)
|
||||
sources.append(x)
|
||||
|
||||
for k in range(10, 11):
|
||||
x = self.mob[k](x)
|
||||
sources.append(x)
|
||||
|
||||
for k in range(11, 12):
|
||||
x = self.mob[k](x)
|
||||
sources.append(x)
|
||||
|
||||
|
||||
|
||||
# apply multibox head to source layers
|
||||
|
||||
loc_x = self.loc[0](sources[0])
|
||||
conf_x = self.conf[0](sources[0])
|
||||
|
||||
loc_x = self.loc[1](loc_x)
|
||||
conf_x = self.conf[1](conf_x)
|
||||
|
||||
max_conf, _ = torch.max(conf_x[:, 0:3, :, :], dim=1, keepdim=True)
|
||||
conf_x = torch.cat((max_conf, conf_x[:, 3:, :, :]), dim=1)
|
||||
|
||||
loc.append(loc_x.permute(0, 2, 3, 1).contiguous())
|
||||
conf.append(conf_x.permute(0, 2, 3, 1).contiguous())
|
||||
|
||||
for i in range(1, len(sources)):
|
||||
x = sources[i]
|
||||
conf.append(self.conf[i+1](x).permute(0, 2, 3, 1).contiguous())
|
||||
loc.append(self.loc[i+1](x).permute(0, 2, 3, 1).contiguous())
|
||||
|
||||
|
||||
features_maps = []
|
||||
for i in range(len(loc)):
|
||||
feat = []
|
||||
feat += [loc[i].size(1), loc[i].size(2)]
|
||||
features_maps += [feat]
|
||||
|
||||
self.priorbox = PriorBox(size, features_maps, cfg)
|
||||
self.priors = self.priorbox.forward()
|
||||
|
||||
loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
|
||||
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
|
||||
|
||||
|
||||
if self.phase == 'test':
|
||||
output = detect_function(cfg,
|
||||
loc.view(loc.size(0), -1, 4), # loc preds
|
||||
self.softmax(conf.view(conf.size(0), -1,
|
||||
self.num_classes)), # conf preds
|
||||
self.priors.type(type(x.data)) # default boxes
|
||||
)
|
||||
|
||||
else:
|
||||
output = (
|
||||
loc.view(loc.size(0), -1, 4),
|
||||
conf.view(conf.size(0), -1, self.num_classes),
|
||||
self.priors
|
||||
)
|
||||
return output
|
||||
|
||||
def load_weights(self, base_file):
|
||||
other, ext = os.path.splitext(base_file)
|
||||
if ext == '.pkl' or ext == '.pth':
|
||||
print('Loading weights into state dict...')
|
||||
mdata = torch.load(base_file,
|
||||
map_location=lambda storage, loc: storage)
|
||||
weights = mdata['weight']
|
||||
epoch = mdata['epoch']
|
||||
self.load_state_dict(weights)
|
||||
print('Finished!')
|
||||
else:
|
||||
print('Sorry, only .pth and .pkl files are supported.')
|
||||
return epoch
|
||||
|
||||
def xavier(self, param):
|
||||
init.xavier_uniform_(param)
|
||||
|
||||
def weights_init(self, m):
|
||||
if isinstance(m, nn.Conv2d):
|
||||
self.xavier(m.weight.data)
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
||||
padding = (kernel_size - 1) // 2
|
||||
super(ConvBNReLU, self).__init__(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
|
||||
"""
|
||||
MobileNet V2 main class
|
||||
Args:
|
||||
num_classes (int): Number of classes
|
||||
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||
inverted_residual_setting: Network structure
|
||||
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||
Set to 1 to turn off rounding
|
||||
"""
|
||||
super(MobileNetV2, self).__init__()
|
||||
block = InvertedResidual
|
||||
input_channel = 32
|
||||
|
||||
if inverted_residual_setting is None:
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[2, 16, 4, 1],
|
||||
[2, 24, 4, 2],
|
||||
[2, 32, 2, 2],
|
||||
[2, 64, 1, 2],
|
||||
[2, 96, 1, 2],
|
||||
]
|
||||
|
||||
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
||||
raise ValueError("inverted_residual_setting should be non-empty "
|
||||
"or a 4-element list, got {}".format(inverted_residual_setting))
|
||||
|
||||
# building first layer
|
||||
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
||||
last_channel = 1280  # assumed default (as in the torchvision MobileNetV2 reference); undefined in the original
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
||||
self.layers = []
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
self.layers.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
||||
input_channel = output_channel
|
||||
|
||||
# weight initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
|
||||
|
||||
def multibox(mobilenet, num_classes):
|
||||
loc_layers = []
|
||||
conf_layers = []
|
||||
|
||||
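# Editor's note (not part of the original file): '+=' with an nn.Sequential extends the
# Python list with the Sequential's two children, so for the first source map forward()
# applies self.loc[0] (the ConvBNReLU) and then self.loc[1] (the Conv2d); the remaining
# sources use self.loc[i+1] / self.conf[i+1].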
loc_layers += nn.Sequential(ConvBNReLU(4, 8, kernel_size=3, stride=2),
|
||||
nn.Conv2d(8, 4, kernel_size=3, padding=1))
|
||||
conf_layers += nn.Sequential(ConvBNReLU(4, 8, kernel_size=3, stride=2),
|
||||
nn.Conv2d(8, 3 + (num_classes-1), kernel_size=3, padding=1))
|
||||
|
||||
loc_layers += [nn.Conv2d(16,
|
||||
4, kernel_size=3, padding=1)]
|
||||
conf_layers += [nn.Conv2d(16,
|
||||
num_classes, kernel_size=3, padding=1)]
|
||||
|
||||
loc_layers += [nn.Conv2d(24,
|
||||
4, kernel_size=3, padding=1)]
|
||||
conf_layers += [nn.Conv2d(24,
|
||||
num_classes, kernel_size=3, padding=1)]
|
||||
|
||||
loc_layers += [nn.Conv2d(32,
|
||||
4, kernel_size=3, padding=1)]
|
||||
conf_layers += [nn.Conv2d(32,
|
||||
num_classes, kernel_size=3, padding=1)]
|
||||
|
||||
loc_layers += [nn.Conv2d(64,
|
||||
4, kernel_size=3, padding=1)]
|
||||
conf_layers += [nn.Conv2d(64,
|
||||
num_classes, kernel_size=3, padding=1)]
|
||||
|
||||
loc_layers += [nn.Conv2d(96,
|
||||
4, kernel_size=3, padding=1)]
|
||||
conf_layers += [nn.Conv2d(96,
|
||||
num_classes, kernel_size=3, padding=1)]
|
||||
|
||||
|
||||
|
||||
return mobilenet, (loc_layers, conf_layers)
|
||||
|
||||
|
||||
def build_s3fd(phase, num_classes=2):
|
||||
base_, head_ = multibox(
|
||||
MobileNetV2().layers, num_classes)
|
||||
|
||||
return S3FD(phase, base_, head_, num_classes)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = build_s3fd('train', num_classes=2)
|
||||
inputs = Variable(torch.randn(4, 3, 640, 640))
|
||||
output = net(inputs)
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from .RPool_Face_C import *
|
||||
from .RPool_Face_Quant import *
|
|
@ -0,0 +1,86 @@
|
|||
## This code is built on https://github.com/yxlijun/S3FD.pytorch
|
||||
#-*- coding:utf-8 -*-
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import os
|
||||
from data.config import cfg
|
||||
import cv2
|
||||
|
||||
WIDER_ROOT = os.path.join(cfg.HOME, 'WIDER_FACE')
|
||||
train_list_file = os.path.join(WIDER_ROOT, 'wider_face_split',
|
||||
'wider_face_train_bbx_gt.txt')
|
||||
val_list_file = os.path.join(WIDER_ROOT, 'wider_face_split',
|
||||
'wider_face_val_bbx_gt.txt')
|
||||
|
||||
WIDER_TRAIN = os.path.join(WIDER_ROOT, 'WIDER_train', 'images')
|
||||
WIDER_VAL = os.path.join(WIDER_ROOT, 'WIDER_val', 'images')
|
||||
|
||||
|
||||
def parse_wider_file(root, file):
|
||||
with open(file, 'r') as fr:
|
||||
lines = fr.readlines()
|
||||
face_count = []
|
||||
img_paths = []
|
||||
face_loc = []
|
||||
img_faces = []
|
||||
count = 0
|
||||
flag = False
|
||||
for k, line in enumerate(lines):
|
||||
line = line.strip().strip('\n')
|
||||
if count > 0:
|
||||
line = line.split(' ')
|
||||
count -= 1
|
||||
loc = [int(line[0]), int(line[1]), int(line[2]), int(line[3])]
|
||||
face_loc += [loc]
|
||||
if flag:
|
||||
face_count += [int(line)]
|
||||
flag = False
|
||||
count = int(line)
|
||||
if 'jpg' in line:
|
||||
img_paths += [os.path.join(root, line)]
|
||||
flag = True
|
||||
|
||||
total_face = 0
|
||||
for k in face_count:
|
||||
face_ = []
|
||||
for x in range(total_face, total_face + k):
|
||||
face_.append(face_loc[x])
|
||||
img_faces += [face_]
|
||||
total_face += k
|
||||
return img_paths, img_faces
|
||||
|
||||
|
||||
def wider_data_file():
|
||||
img_paths, bbox = parse_wider_file(WIDER_TRAIN, train_list_file)
|
||||
fw = open(cfg.FACE.TRAIN_FILE, 'w')
|
||||
for index in range(len(img_paths)):
|
||||
path = img_paths[index]
|
||||
boxes = bbox[index]
|
||||
fw.write(path)
|
||||
fw.write(' {}'.format(len(boxes)))
|
||||
for box in boxes:
|
||||
data = ' {} {} {} {} {}'.format(box[0], box[1], box[2], box[3], 1)
|
||||
fw.write(data)
|
||||
fw.write('\n')
|
||||
fw.close()
|
||||
|
||||
img_paths, bbox = parse_wider_file(WIDER_VAL, val_list_file)
|
||||
fw = open(cfg.FACE.VAL_FILE, 'w')
|
||||
for index in range(len(img_paths)):
|
||||
path = img_paths[index]
|
||||
boxes = bbox[index]
|
||||
fw.write(path)
|
||||
fw.write(' {}'.format(len(boxes)))
|
||||
for box in boxes:
|
||||
data = ' {} {} {} {} {}'.format(box[0], box[1], box[2], box[3], 1)
|
||||
fw.write(data)
|
||||
fw.write('\n')
|
||||
fw.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
wider_data_file()
|
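# Editor's note (illustrative, not part of the original file): each line written to
# cfg.FACE.TRAIN_FILE / cfg.FACE.VAL_FILE has the form
#   <image_path> <num_boxes> <x> <y> <w> <h> 1  ... repeated per box,
# i.e. the four WIDER box fields followed by the constant class label 1.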
|
@ -0,0 +1,10 @@
|
|||
Cython==0.29.15
|
||||
easydict==1.9
|
||||
importlib-metadata==1.5.0
|
||||
matplotlib==3.2.1
|
||||
opencv-python-headless==4.2.0.32
|
||||
PyYAML==3.12
|
||||
scikit-image==0.15.0
|
||||
tensorboard==1.14.0
|
||||
tensorboardX==1.9
|
||||
tqdm==4.36.1
|
|
@ -0,0 +1,244 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from data import *
|
||||
from layers.modules import MultiBoxLoss
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.init as init
|
||||
import torch.utils.data as data
|
||||
import numpy as np
|
||||
import argparse
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
from data.choose_config import cfg
|
||||
cfg = cfg.cfg
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
return v.lower() in ("yes", "true", "t", "1")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='S3FD face Detector Training With Pytorch')
|
||||
train_set = parser.add_mutually_exclusive_group()
|
||||
parser.add_argument('--dataset',
|
||||
default='face',
|
||||
choices=['hand', 'face', 'head'],
|
||||
help='Train target')
|
||||
parser.add_argument('--basenet',
|
||||
default='vgg16_reducedfc.pth',
|
||||
help='Pretrained base model')
|
||||
parser.add_argument('--batch_size',
|
||||
default=16, type=int,
|
||||
help='Batch size for training')
|
||||
parser.add_argument('--resume',
|
||||
default=None, type=str,
|
||||
help='Checkpoint state_dict file to resume training from')
|
||||
parser.add_argument('--model_arch',
|
||||
default='RPool_Face_C', type=str,
|
||||
choices=['RPool_Face_C', 'RPool_Face_Quant', 'RPool_Face_QVGA_monochrome'],
|
||||
help='choose architecture among rpool variants')
|
||||
parser.add_argument('--num_workers',
|
||||
default=128, type=int,
|
||||
help='Number of workers used in dataloading')
|
||||
parser.add_argument('--cuda',
|
||||
default=True, type=str2bool,
|
||||
help='Use CUDA to train model')
|
||||
parser.add_argument('--lr', '--learning-rate',
|
||||
default=1e-2, type=float,
|
||||
help='initial learning rate')
|
||||
parser.add_argument('--momentum',
|
||||
default=0.9, type=float,
|
||||
help='Momentum value for optim')
|
||||
parser.add_argument('--weight_decay',
|
||||
default=5e-4, type=float,
|
||||
help='Weight decay for SGD')
|
||||
parser.add_argument('--gamma',
|
||||
default=0.1, type=float,
|
||||
help='Gamma update for SGD')
|
||||
parser.add_argument('--multigpu',
|
||||
default=False, type=str2bool,
|
||||
help='Use multi-GPU training')
|
||||
parser.add_argument('--save_folder',
|
||||
default='weights/',
|
||||
help='Directory for saving checkpoint models')
|
||||
parser.add_argument('--epochs',
|
||||
default=300, type=int,
|
||||
help='total epochs')
|
||||
parser.add_argument('--save_frequency',
|
||||
default=5000, type=int,
|
||||
help='iterations interval after which checkpoint is saved')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if args.cuda:
|
||||
torch.set_default_tensor_type('torch.cuda.FloatTensor')
|
||||
if not args.cuda:
|
||||
print("WARNING: It looks like you have a CUDA device, but aren't " +
|
||||
"using CUDA.\nRun with --cuda for optimal training speed.")
|
||||
torch.set_default_tensor_type('torch.FloatTensor')
|
||||
else:
|
||||
torch.set_default_tensor_type('torch.FloatTensor')
|
||||
|
||||
if not os.path.exists(args.save_folder):
|
||||
os.makedirs(args.save_folder)
|
||||
|
||||
|
||||
train_dataset = WIDERDetection(cfg.FACE.TRAIN_FILE, mode='train', mono_mode=cfg.IS_MONOCHROME)
|
||||
val_dataset = WIDERDetection(cfg.FACE.VAL_FILE, mode='val', mono_mode=cfg.IS_MONOCHROME)
|
||||
|
||||
train_loader = data.DataLoader(train_dataset, args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
shuffle=True,
|
||||
collate_fn=detection_collate,
|
||||
pin_memory=True)
|
||||
|
||||
val_batchsize = args.batch_size // 2
|
||||
val_loader = data.DataLoader(val_dataset, val_batchsize,
|
||||
num_workers=args.num_workers,
|
||||
shuffle=False,
|
||||
collate_fn=detection_collate,
|
||||
pin_memory=True)
|
||||
|
||||
min_loss = np.inf
|
||||
start_epoch = 0
|
||||
|
||||
module = import_module('models.' + args.model_arch)
|
||||
net = module.build_s3fd('train', cfg.NUM_CLASSES)
|
||||
|
||||
|
||||
|
||||
if args.cuda:
|
||||
if args.multigpu:
|
||||
net = torch.nn.DataParallel(net)
|
||||
net = net.cuda()
|
||||
cudnn.benchmark = True
|
||||
|
||||
if args.resume:
|
||||
print('Resuming training, loading {}...'.format(args.resume))
|
||||
net.load_state_dict(torch.load(args.resume))
|
||||
|
||||
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
criterion = MultiBoxLoss(cfg, args.dataset, args.cuda)
|
||||
print('Loading wider dataset...')
|
||||
print('Using the specified args:')
|
||||
print(args)
|
||||
|
||||
|
||||
def train():
|
||||
step_index = 0
|
||||
iteration = 0
|
||||
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
net.train()
|
||||
losses = 0
|
||||
train_loader_len = len(train_loader)
|
||||
for batch_idx, (images, targets) in enumerate(train_loader):
|
||||
adjust_learning_rate(optimizer, epoch, batch_idx, train_loader_len)
|
||||
|
||||
if args.cuda:
|
||||
images = images.cuda()
|
||||
targets = [ann.cuda()
|
||||
for ann in targets]
|
||||
else:
|
||||
images = images
|
||||
targets = [ann for ann in targets]
|
||||
|
||||
|
||||
t0 = time.time()
|
||||
out = net(images)
|
||||
# backprop
|
||||
optimizer.zero_grad()
|
||||
loss_l, loss_c = criterion(out, targets)
|
||||
loss = loss_l + loss_c
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
t1 = time.time()
|
||||
losses += loss.item()
|
||||
|
||||
if iteration % 10 == 0:
|
||||
tloss = losses / (batch_idx + 1)
|
||||
print('Timer: %.4f' % (t1 - t0))
|
||||
print('epoch:' + repr(epoch) + ' || iter:' +
|
||||
repr(iteration) + ' || Loss:%.4f' % (tloss))
|
||||
print('->> conf loss:{:.4f} || loc loss:{:.4f}'.format(
|
||||
loss_c.item(), loss_l.item()))
|
||||
print('->>lr:{:.6f}'.format(optimizer.param_groups[0]['lr']))
|
||||
|
||||
if iteration != 0 and iteration % args.save_frequency == 0:
|
||||
print('Saving state, iter:', iteration)
|
||||
file = 'rpool_' + args.dataset + '_' + repr(iteration) + '_checkpoint.pth'
|
||||
torch.save(net.state_dict(),
|
||||
os.path.join(args.save_folder, file))
|
||||
iteration += 1
|
||||
|
||||
val(epoch)
|
||||
if iteration == cfg.MAX_STEPS:
|
||||
break
|
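# Editor's note (illustrative, not part of the original file): with the defaults
# (--dataset face, --save_folder weights/, --save_frequency 5000) train() writes
# periodic checkpoints such as weights/rpool_face_5000_checkpoint.pth, while val()
# below keeps only the best model as weights/<model_arch>_best_state.pth.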
||||
|
||||
|
||||
def val(epoch):
|
||||
net.eval()
|
||||
loc_loss = 0
|
||||
conf_loss = 0
|
||||
step = 0
|
||||
t1 = time.time()
|
||||
with torch.no_grad():
|
||||
for batch_idx, (images, targets) in enumerate(val_loader):
|
||||
if args.cuda:
|
||||
images = images.cuda()
|
||||
targets = [ann.cuda()
|
||||
for ann in targets]
|
||||
else:
|
||||
images = images
|
||||
targets = [ann for ann in targets]
|
||||
|
||||
out = net(images)
|
||||
loss_l, loss_c = criterion(out, targets)
|
||||
loss = loss_l + loss_c
|
||||
loc_loss += loss_l.item()
|
||||
conf_loss += loss_c.item()
|
||||
step += 1
|
||||
|
||||
tloss = (loc_loss + conf_loss) / step
|
||||
t2 = time.time()
|
||||
print('Timer: %.4f' % (t2 - t1))
|
||||
print('test epoch:' + repr(epoch) + ' || Loss:%.4f' % (tloss))
|
||||
|
||||
global min_loss
|
||||
if tloss < min_loss:
|
||||
print('Saving best state,epoch', epoch)
|
||||
file = '{}_best_state.pth'.format(args.model_arch)
|
||||
torch.save(net.state_dict(), os.path.join(
|
||||
args.save_folder, file))
|
||||
min_loss = tloss
|
||||
|
||||
|
||||
|
||||
from math import cos, pi
|
||||
def adjust_learning_rate(optimizer, epoch, iteration, num_iter):
|
||||
lr = optimizer.param_groups[0]['lr']
|
||||
|
||||
warmup_epoch = 5
|
||||
warmup_iter = warmup_epoch * num_iter
|
||||
current_iter = iteration + epoch * num_iter
|
||||
max_iter = args.epochs * num_iter
|
||||
|
||||
lr = args.lr * (1 + cos(pi * (current_iter - warmup_iter) / (max_iter - warmup_iter))) / 2
|
||||
|
||||
if epoch < warmup_epoch:
|
||||
lr = args.lr * current_iter / warmup_iter
|
||||
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
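# Editor's note (illustrative, not part of the original file): with the defaults
# (lr = 1e-2, epochs = 300, warmup_epoch = 5) this gives
#   epoch < 5 : lr = 1e-2 * current_iter / (5 * num_iter)     (linear warm-up)
#   otherwise : lr = 1e-2 * (1 + cos(pi * progress)) / 2      (cosine decay)
# where progress goes from 0 at the end of warm-up to 1 at the final iteration.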
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .augmentations import *
|
|
@ -0,0 +1,862 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
import cv2
|
||||
import numpy as np
|
||||
import types
|
||||
from PIL import Image, ImageEnhance, ImageDraw
|
||||
import math
|
||||
import six
|
||||
|
||||
import sys; sys.path.append('../')
|
||||
from data.choose_config import cfg
|
||||
cfg = cfg.cfg
|
||||
import random
|
||||
|
||||
|
||||
class sampler():
|
||||
|
||||
def __init__(self,
|
||||
max_sample,
|
||||
max_trial,
|
||||
min_scale,
|
||||
max_scale,
|
||||
min_aspect_ratio,
|
||||
max_aspect_ratio,
|
||||
min_jaccard_overlap,
|
||||
max_jaccard_overlap,
|
||||
min_object_coverage,
|
||||
max_object_coverage,
|
||||
use_square=False):
|
||||
self.max_sample = max_sample
|
||||
self.max_trial = max_trial
|
||||
self.min_scale = min_scale
|
||||
self.max_scale = max_scale
|
||||
self.min_aspect_ratio = min_aspect_ratio
|
||||
self.max_aspect_ratio = max_aspect_ratio
|
||||
self.min_jaccard_overlap = min_jaccard_overlap
|
||||
self.max_jaccard_overlap = max_jaccard_overlap
|
||||
self.min_object_coverage = min_object_coverage
|
||||
self.max_object_coverage = max_object_coverage
|
||||
self.use_square = use_square
|
||||
|
||||
|
||||
def intersect(box_a, box_b):
|
||||
max_xy = np.minimum(box_a[:, 2:], box_b[2:])
|
||||
min_xy = np.maximum(box_a[:, :2], box_b[:2])
|
||||
inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf)
|
||||
return inter[:, 0] * inter[:, 1]
|
||||
|
||||
|
||||
def jaccard_numpy(box_a, box_b):
|
||||
"""Compute the jaccard overlap of two sets of boxes. The jaccard overlap
|
||||
is simply the intersection over union of two boxes.
|
||||
E.g.:
|
||||
A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
|
||||
Args:
|
||||
box_a: Multiple bounding boxes, Shape: [num_boxes,4]
|
||||
box_b: Single bounding box, Shape: [4]
|
||||
Return:
|
||||
jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]]
|
||||
"""
|
||||
inter = intersect(box_a, box_b)
|
||||
area_a = ((box_a[:, 2] - box_a[:, 0]) *
|
||||
(box_a[:, 3] - box_a[:, 1])) # [A,B]
|
||||
area_b = ((box_b[2] - box_b[0]) *
|
||||
(box_b[3] - box_b[1])) # [A,B]
|
||||
union = area_a + area_b - inter
|
||||
return inter / union # [A,B]
|
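# Editor's note (illustrative, not part of the original file):
#   box_a = np.array([[0., 0., 2., 2.]])   # area 4
#   box_b = np.array([0., 0., 1., 1.])     # area 1, fully inside box_a
#   jaccard_numpy(box_a, box_b)            # -> array([0.25]) = 1 / (4 + 1 - 1)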
||||
|
||||
|
||||
class bbox():
|
||||
|
||||
def __init__(self, xmin, ymin, xmax, ymax):
|
||||
self.xmin = xmin
|
||||
self.ymin = ymin
|
||||
self.xmax = xmax
|
||||
self.ymax = ymax
|
||||
|
||||
|
||||
def random_brightness(img):
|
||||
prob = np.random.uniform(0, 1)
|
||||
if prob < cfg.brightness_prob:
|
||||
delta = np.random.uniform(-cfg.brightness_delta,
|
||||
cfg.brightness_delta) + 1
|
||||
img = ImageEnhance.Brightness(img).enhance(delta)
|
||||
return img
|
||||
|
||||
|
||||
def random_contrast(img):
|
||||
prob = np.random.uniform(0, 1)
|
||||
if prob < cfg.contrast_prob:
|
||||
delta = np.random.uniform(-cfg.contrast_delta,
|
||||
cfg.contrast_delta) + 1
|
||||
img = ImageEnhance.Contrast(img).enhance(delta)
|
||||
return img
|
||||
|
||||
|
||||
def random_saturation(img):
|
||||
prob = np.random.uniform(0, 1)
|
||||
if prob < cfg.saturation_prob:
|
||||
delta = np.random.uniform(-cfg.saturation_delta,
|
||||
cfg.saturation_delta) + 1
|
||||
img = ImageEnhance.Color(img).enhance(delta)
|
||||
return img
|
||||
|
||||
|
||||
def random_hue(img):
|
||||
prob = np.random.uniform(0, 1)
|
||||
if prob < cfg.hue_prob:
|
||||
delta = np.random.uniform(-cfg.hue_delta, cfg.hue_delta)
|
||||
img_hsv = np.array(img.convert('HSV'))
|
||||
img_hsv[:, :, 0] = img_hsv[:, :, 0] + delta
|
||||
img = Image.fromarray(img_hsv, mode='HSV').convert('RGB')
|
||||
return img
|
||||
|
||||
|
||||
def distort_image(img):
|
||||
prob = np.random.uniform(0, 1)
|
||||
# Apply different distort order
|
||||
if prob > 0.5:
|
||||
img = random_brightness(img)
|
||||
img = random_contrast(img)
|
||||
img = random_saturation(img)
|
||||
img = random_hue(img)
|
||||
else:
|
||||
img = random_brightness(img)
|
||||
img = random_saturation(img)
|
||||
img = random_hue(img)
|
||||
img = random_contrast(img)
|
||||
return img
|
||||
|
||||
|
||||
def meet_emit_constraint(src_bbox, sample_bbox):
|
||||
center_x = (src_bbox.xmax + src_bbox.xmin) / 2
|
||||
center_y = (src_bbox.ymax + src_bbox.ymin) / 2
|
||||
if center_x >= sample_bbox.xmin and \
|
||||
center_x <= sample_bbox.xmax and \
|
||||
center_y >= sample_bbox.ymin and \
|
||||
center_y <= sample_bbox.ymax:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def project_bbox(object_bbox, sample_bbox):
|
||||
if object_bbox.xmin >= sample_bbox.xmax or \
|
||||
object_bbox.xmax <= sample_bbox.xmin or \
|
||||
object_bbox.ymin >= sample_bbox.ymax or \
|
||||
object_bbox.ymax <= sample_bbox.ymin:
|
||||
return False
|
||||
else:
|
||||
proj_bbox = bbox(0, 0, 0, 0)
|
||||
sample_width = sample_bbox.xmax - sample_bbox.xmin
|
||||
sample_height = sample_bbox.ymax - sample_bbox.ymin
|
||||
proj_bbox.xmin = (object_bbox.xmin - sample_bbox.xmin) / sample_width
|
||||
proj_bbox.ymin = (object_bbox.ymin - sample_bbox.ymin) / sample_height
|
||||
proj_bbox.xmax = (object_bbox.xmax - sample_bbox.xmin) / sample_width
|
||||
proj_bbox.ymax = (object_bbox.ymax - sample_bbox.ymin) / sample_height
|
||||
proj_bbox = clip_bbox(proj_bbox)
|
||||
if bbox_area(proj_bbox) > 0:
|
||||
return proj_bbox
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def transform_labels(bbox_labels, sample_bbox):
|
||||
sample_labels = []
|
||||
for i in range(len(bbox_labels)):
|
||||
sample_label = []
|
||||
object_bbox = bbox(bbox_labels[i][1], bbox_labels[i][2],
|
||||
bbox_labels[i][3], bbox_labels[i][4])
|
||||
if not meet_emit_constraint(object_bbox, sample_bbox):
|
||||
continue
|
||||
proj_bbox = project_bbox(object_bbox, sample_bbox)
|
||||
if proj_bbox:
|
||||
sample_label.append(bbox_labels[i][0])
|
||||
sample_label.append(float(proj_bbox.xmin))
|
||||
sample_label.append(float(proj_bbox.ymin))
|
||||
sample_label.append(float(proj_bbox.xmax))
|
||||
sample_label.append(float(proj_bbox.ymax))
|
||||
sample_label = sample_label + bbox_labels[i][5:]
|
||||
sample_labels.append(sample_label)
|
||||
return sample_labels
|
||||
|
||||
|
||||
def expand_image(img, bbox_labels, img_width, img_height):
|
||||
prob = np.random.uniform(0, 1)
|
||||
if prob < cfg.expand_prob:
|
||||
if cfg.expand_max_ratio - 1 >= 0.01:
|
||||
expand_ratio = np.random.uniform(1, cfg.expand_max_ratio)
|
||||
height = int(img_height * expand_ratio)
|
||||
width = int(img_width * expand_ratio)
|
||||
h_off = math.floor(np.random.uniform(0, height - img_height))
|
||||
w_off = math.floor(np.random.uniform(0, width - img_width))
|
||||
expand_bbox = bbox(-w_off / img_width, -h_off / img_height,
|
||||
(width - w_off) / img_width,
|
||||
(height - h_off) / img_height)
|
||||
expand_img = np.ones((height, width, 3))
|
||||
expand_img = np.uint8(expand_img * np.squeeze(cfg.img_mean))
|
||||
expand_img = Image.fromarray(expand_img)
|
||||
expand_img.paste(img, (int(w_off), int(h_off)))
|
||||
bbox_labels = transform_labels(bbox_labels, expand_bbox)
|
||||
return expand_img, bbox_labels, width, height
|
||||
return img, bbox_labels, img_width, img_height
|
||||
|
||||
|
||||
def clip_bbox(src_bbox):
|
||||
src_bbox.xmin = max(min(src_bbox.xmin, 1.0), 0.0)
|
||||
src_bbox.ymin = max(min(src_bbox.ymin, 1.0), 0.0)
|
||||
src_bbox.xmax = max(min(src_bbox.xmax, 1.0), 0.0)
|
||||
src_bbox.ymax = max(min(src_bbox.ymax, 1.0), 0.0)
|
||||
return src_bbox
|
||||
|
||||
|
||||
def bbox_area(src_bbox):
|
||||
if src_bbox.xmax < src_bbox.xmin or src_bbox.ymax < src_bbox.ymin:
|
||||
return 0.
|
||||
else:
|
||||
width = src_bbox.xmax - src_bbox.xmin
|
||||
height = src_bbox.ymax - src_bbox.ymin
|
||||
return width * height
|
||||
|
||||
|
||||
def intersect_bbox(bbox1, bbox2):
|
||||
if bbox2.xmin > bbox1.xmax or bbox2.xmax < bbox1.xmin or \
|
||||
bbox2.ymin > bbox1.ymax or bbox2.ymax < bbox1.ymin:
|
||||
intersection_box = bbox(0.0, 0.0, 0.0, 0.0)
|
||||
else:
|
||||
intersection_box = bbox(
|
||||
max(bbox1.xmin, bbox2.xmin),
|
||||
max(bbox1.ymin, bbox2.ymin),
|
||||
min(bbox1.xmax, bbox2.xmax), min(bbox1.ymax, bbox2.ymax))
|
||||
return intersection_box
|
||||
|
||||
|
||||
def bbox_coverage(bbox1, bbox2):
|
||||
inter_box = intersect_bbox(bbox1, bbox2)
|
||||
intersect_size = bbox_area(inter_box)
|
||||
|
||||
if intersect_size > 0:
|
||||
bbox1_size = bbox_area(bbox1)
|
||||
return intersect_size / bbox1_size
|
||||
else:
|
||||
return 0.
|
||||
|
||||
|
||||
def generate_batch_random_samples(batch_sampler, bbox_labels, image_width,
|
||||
image_height, scale_array, resize_width,
|
||||
resize_height):
|
||||
sampled_bbox = []
|
||||
for sampler in batch_sampler:
|
||||
found = 0
|
||||
for i in range(sampler.max_trial):
|
||||
if found >= sampler.max_sample:
|
||||
break
|
||||
sample_bbox = data_anchor_sampling(
|
||||
sampler, bbox_labels, image_width, image_height, scale_array,
|
||||
resize_width, resize_height)
|
||||
if sample_bbox == 0:
|
||||
break
|
||||
if satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
|
||||
sampled_bbox.append(sample_bbox)
|
||||
found = found + 1
|
||||
return sampled_bbox
|
||||
|
||||
|
||||
def data_anchor_sampling(sampler, bbox_labels, image_width, image_height,
|
||||
scale_array, resize_width, resize_height):
|
||||
num_gt = len(bbox_labels)
|
||||
# np.random.randint range: [low, high)
|
||||
rand_idx = np.random.randint(0, num_gt) if num_gt != 0 else 0
|
||||
|
||||
if num_gt != 0:
|
||||
norm_xmin = bbox_labels[rand_idx][1]
|
||||
norm_ymin = bbox_labels[rand_idx][2]
|
||||
norm_xmax = bbox_labels[rand_idx][3]
|
||||
norm_ymax = bbox_labels[rand_idx][4]
|
||||
|
||||
xmin = norm_xmin * image_width
|
||||
ymin = norm_ymin * image_height
|
||||
wid = image_width * (norm_xmax - norm_xmin)
|
||||
hei = image_height * (norm_ymax - norm_ymin)
|
||||
range_size = 0
|
||||
|
||||
area = wid * hei
|
||||
for scale_ind in range(0, len(scale_array) - 1):
|
||||
if area > scale_array[scale_ind] ** 2 and area < \
|
||||
scale_array[scale_ind + 1] ** 2:
|
||||
range_size = scale_ind + 1
|
||||
break
|
||||
|
||||
if area > scale_array[len(scale_array) - 2]**2:
|
||||
range_size = len(scale_array) - 2
|
||||
scale_choose = 0.0
|
||||
if range_size == 0:
|
||||
rand_idx_size = 0
|
||||
else:
|
||||
# np.random.randint range: [low, high)
|
||||
rng_rand_size = np.random.randint(0, range_size + 1)
|
||||
rand_idx_size = rng_rand_size % (range_size + 1)
|
||||
|
||||
if rand_idx_size == range_size:
|
||||
min_resize_val = scale_array[rand_idx_size] / 2.0
|
||||
max_resize_val = min(2.0 * scale_array[rand_idx_size],
|
||||
2 * math.sqrt(wid * hei))
|
||||
scale_choose = random.uniform(min_resize_val, max_resize_val)
|
||||
else:
|
||||
min_resize_val = scale_array[rand_idx_size] / 2.0
|
||||
max_resize_val = 2.0 * scale_array[rand_idx_size]
|
||||
scale_choose = random.uniform(min_resize_val, max_resize_val)
|
||||
|
||||
sample_bbox_size = wid * resize_width / scale_choose
|
||||
|
||||
w_off_orig = 0.0
|
||||
h_off_orig = 0.0
|
||||
if sample_bbox_size < max(image_height, image_width):
|
||||
if wid <= sample_bbox_size:
|
||||
w_off_orig = np.random.uniform(xmin + wid - sample_bbox_size,
|
||||
xmin)
|
||||
else:
|
||||
w_off_orig = np.random.uniform(xmin,
|
||||
xmin + wid - sample_bbox_size)
|
||||
|
||||
if hei <= sample_bbox_size:
|
||||
h_off_orig = np.random.uniform(ymin + hei - sample_bbox_size,
|
||||
ymin)
|
||||
else:
|
||||
h_off_orig = np.random.uniform(ymin,
|
||||
ymin + hei - sample_bbox_size)
|
||||
|
||||
else:
|
||||
w_off_orig = np.random.uniform(image_width - sample_bbox_size, 0.0)
|
||||
h_off_orig = np.random.uniform(
|
||||
image_height - sample_bbox_size, 0.0)
|
||||
|
||||
w_off_orig = math.floor(w_off_orig)
|
||||
h_off_orig = math.floor(h_off_orig)
|
||||
|
||||
# Figure out top left coordinates.
|
||||
w_off = 0.0
|
||||
h_off = 0.0
|
||||
w_off = float(w_off_orig / image_width)
|
||||
h_off = float(h_off_orig / image_height)
|
||||
|
||||
sampled_bbox = bbox(w_off, h_off,
|
||||
w_off + float(sample_bbox_size / image_width),
|
||||
h_off + float(sample_bbox_size / image_height))
|
||||
|
||||
return sampled_bbox
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def jaccard_overlap(sample_bbox, object_bbox):
|
||||
if sample_bbox.xmin >= object_bbox.xmax or \
|
||||
sample_bbox.xmax <= object_bbox.xmin or \
|
||||
sample_bbox.ymin >= object_bbox.ymax or \
|
||||
sample_bbox.ymax <= object_bbox.ymin:
|
||||
return 0
|
||||
intersect_xmin = max(sample_bbox.xmin, object_bbox.xmin)
|
||||
intersect_ymin = max(sample_bbox.ymin, object_bbox.ymin)
|
||||
intersect_xmax = min(sample_bbox.xmax, object_bbox.xmax)
|
||||
intersect_ymax = min(sample_bbox.ymax, object_bbox.ymax)
|
||||
intersect_size = (intersect_xmax - intersect_xmin) * (
|
||||
intersect_ymax - intersect_ymin)
|
||||
sample_bbox_size = bbox_area(sample_bbox)
|
||||
object_bbox_size = bbox_area(object_bbox)
|
||||
overlap = intersect_size / (
|
||||
sample_bbox_size + object_bbox_size - intersect_size)
|
||||
return overlap
|
||||
|
||||
|
||||
def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
|
||||
if sampler.min_jaccard_overlap == 0 and sampler.max_jaccard_overlap == 0:
|
||||
has_jaccard_overlap = False
|
||||
else:
|
||||
has_jaccard_overlap = True
|
||||
if sampler.min_object_coverage == 0 and sampler.max_object_coverage == 0:
|
||||
has_object_coverage = False
|
||||
else:
|
||||
has_object_coverage = True
|
||||
|
||||
if not has_jaccard_overlap and not has_object_coverage:
|
||||
return True
|
||||
found = False
|
||||
for i in range(len(bbox_labels)):
|
||||
object_bbox = bbox(bbox_labels[i][1], bbox_labels[i][2],
|
||||
bbox_labels[i][3], bbox_labels[i][4])
|
||||
if has_jaccard_overlap:
|
||||
overlap = jaccard_overlap(sample_bbox, object_bbox)
|
||||
if sampler.min_jaccard_overlap != 0 and \
|
||||
overlap < sampler.min_jaccard_overlap:
|
||||
continue
|
||||
if sampler.max_jaccard_overlap != 0 and \
|
||||
overlap > sampler.max_jaccard_overlap:
|
||||
continue
|
||||
found = True
|
||||
if has_object_coverage:
|
||||
object_coverage = bbox_coverage(object_bbox, sample_bbox)
|
||||
if sampler.min_object_coverage != 0 and \
|
||||
object_coverage < sampler.min_object_coverage:
|
||||
continue
|
||||
if sampler.max_object_coverage != 0 and \
|
||||
object_coverage > sampler.max_object_coverage:
|
||||
continue
|
||||
found = True
|
||||
if found:
|
||||
return True
|
||||
return found
|
||||
|
||||
|
||||
def crop_image_sampling(img, bbox_labels, sample_bbox, image_width,
|
||||
image_height, resize_width, resize_height,
|
||||
min_face_size):
|
||||
# no clipping here
|
||||
xmin = int(sample_bbox.xmin * image_width)
|
||||
xmax = int(sample_bbox.xmax * image_width)
|
||||
ymin = int(sample_bbox.ymin * image_height)
|
||||
ymax = int(sample_bbox.ymax * image_height)
|
||||
w_off = xmin
|
||||
h_off = ymin
|
||||
width = xmax - xmin
|
||||
height = ymax - ymin
|
||||
|
||||
cross_xmin = max(0.0, float(w_off))
|
||||
cross_ymin = max(0.0, float(h_off))
|
||||
cross_xmax = min(float(w_off + width - 1.0), float(image_width))
|
||||
cross_ymax = min(float(h_off + height - 1.0), float(image_height))
|
||||
cross_width = cross_xmax - cross_xmin
|
||||
cross_height = cross_ymax - cross_ymin
|
||||
|
||||
roi_xmin = 0 if w_off >= 0 else abs(w_off)
|
||||
roi_ymin = 0 if h_off >= 0 else abs(h_off)
|
||||
roi_width = cross_width
|
||||
roi_height = cross_height
|
||||
|
||||
roi_y1 = int(roi_ymin)
|
||||
roi_y2 = int(roi_ymin + roi_height)
|
||||
roi_x1 = int(roi_xmin)
|
||||
roi_x2 = int(roi_xmin + roi_width)
|
||||
|
||||
cross_y1 = int(cross_ymin)
|
||||
cross_y2 = int(cross_ymin + cross_height)
|
||||
cross_x1 = int(cross_xmin)
|
||||
cross_x2 = int(cross_xmin + cross_width)
|
||||
|
||||
sample_img = np.zeros((height, width, 3))
|
||||
# print(sample_img.shape)
|
||||
sample_img[roi_y1 : roi_y2, roi_x1 : roi_x2] = \
|
||||
img[cross_y1: cross_y2, cross_x1: cross_x2]
|
||||
sample_img = cv2.resize(
|
||||
sample_img, (resize_width, resize_height), interpolation=cv2.INTER_AREA)
|
||||
|
||||
resize_val = resize_width
|
||||
sample_labels = transform_labels_sampling(bbox_labels, sample_bbox,
|
||||
resize_val, min_face_size)
|
||||
return sample_img, sample_labels
|
||||
|
||||
|
||||
def transform_labels_sampling(bbox_labels, sample_bbox, resize_val,
|
||||
min_face_size):
|
||||
sample_labels = []
|
||||
for i in range(len(bbox_labels)):
|
||||
sample_label = []
|
||||
object_bbox = bbox(bbox_labels[i][1], bbox_labels[i][2],
|
||||
bbox_labels[i][3], bbox_labels[i][4])
|
||||
if not meet_emit_constraint(object_bbox, sample_bbox):
|
||||
continue
|
||||
proj_bbox = project_bbox(object_bbox, sample_bbox)
|
||||
if proj_bbox:
|
||||
real_width = float((proj_bbox.xmax - proj_bbox.xmin) * resize_val)
|
||||
real_height = float((proj_bbox.ymax - proj_bbox.ymin) * resize_val)
|
||||
if real_width * real_height < float(min_face_size * min_face_size):
|
||||
continue
|
||||
else:
|
||||
sample_label.append(bbox_labels[i][0])
|
||||
sample_label.append(float(proj_bbox.xmin))
|
||||
sample_label.append(float(proj_bbox.ymin))
|
||||
sample_label.append(float(proj_bbox.xmax))
|
||||
sample_label.append(float(proj_bbox.ymax))
|
||||
sample_label = sample_label + bbox_labels[i][5:]
|
||||
sample_labels.append(sample_label)
|
||||
|
||||
return sample_labels
|
||||
|
||||
|
||||
def generate_sample(sampler, image_width, image_height):
|
||||
scale = np.random.uniform(sampler.min_scale, sampler.max_scale)
|
||||
aspect_ratio = np.random.uniform(sampler.min_aspect_ratio,
|
||||
sampler.max_aspect_ratio)
|
||||
aspect_ratio = max(aspect_ratio, (scale**2.0))
|
||||
aspect_ratio = min(aspect_ratio, 1 / (scale**2.0))
|
||||
|
||||
bbox_width = scale * (aspect_ratio**0.5)
|
||||
bbox_height = scale / (aspect_ratio**0.5)
|
||||
|
||||
# guarantee a squared image patch after cropping
|
||||
if sampler.use_square:
|
||||
if image_height < image_width:
|
||||
bbox_width = bbox_height * image_height / image_width
|
||||
else:
|
||||
bbox_height = bbox_width * image_width / image_height
|
||||
|
||||
xmin_bound = 1 - bbox_width
|
||||
ymin_bound = 1 - bbox_height
|
||||
xmin = np.random.uniform(0, xmin_bound)
|
||||
ymin = np.random.uniform(0, ymin_bound)
|
||||
xmax = xmin + bbox_width
|
||||
ymax = ymin + bbox_height
|
||||
sampled_bbox = bbox(xmin, ymin, xmax, ymax)
|
||||
return sampled_bbox
|
||||
|
||||
|
||||
def generate_batch_samples(batch_sampler, bbox_labels, image_width,
|
||||
image_height):
|
||||
sampled_bbox = []
|
||||
for sampler in batch_sampler:
|
||||
found = 0
|
||||
for i in range(sampler.max_trial):
|
||||
if found >= sampler.max_sample:
|
||||
break
|
||||
sample_bbox = generate_sample(sampler, image_width, image_height)
|
||||
if satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
|
||||
sampled_bbox.append(sample_bbox)
|
||||
found = found + 1
|
||||
return sampled_bbox
|
||||
|
||||
|
||||
def crop_image(img, bbox_labels, sample_bbox, image_width, image_height,
|
||||
resize_width, resize_height, min_face_size):
|
||||
sample_bbox = clip_bbox(sample_bbox)
|
||||
xmin = int(sample_bbox.xmin * image_width)
|
||||
xmax = int(sample_bbox.xmax * image_width)
|
||||
ymin = int(sample_bbox.ymin * image_height)
|
||||
ymax = int(sample_bbox.ymax * image_height)
|
||||
|
||||
sample_img = img[ymin:ymax, xmin:xmax]
|
||||
resize_val = resize_width
|
||||
sample_labels = transform_labels_sampling(bbox_labels, sample_bbox,
|
||||
resize_val, min_face_size)
|
||||
return sample_img, sample_labels
|
||||
|
||||
|
||||
def to_chw_bgr(image):
|
||||
"""
|
||||
Transpose image from HWC to CHW and from RGB to BGR.
|
||||
Args:
|
||||
image (np.array): an image with HWC and RGB layout.
|
||||
"""
|
||||
# HWC to CHW
|
||||
if len(image.shape) == 3:
|
||||
image = np.swapaxes(image, 1, 2)
|
||||
image = np.swapaxes(image, 1, 0)
|
||||
# RGB to BGR
|
||||
image = image[[2, 1, 0], :, :]
|
||||
return image
|
||||
|
||||
|
||||
def anchor_crop_image_sampling(img,
|
||||
bbox_labels,
|
||||
scale_array,
|
||||
img_width,
|
||||
img_height):
|
||||
mean = np.array([104, 117, 123], dtype=np.float32)
|
||||
maxSize = 12000 # max size
|
||||
infDistance = 9999999
|
||||
bbox_labels = np.array(bbox_labels)
|
||||
scale = np.array([img_width, img_height, img_width, img_height])
|
||||
|
||||
boxes = bbox_labels[:, 1:5] * scale
|
||||
labels = bbox_labels[:, 0]
|
||||
|
||||
boxArea = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)
|
||||
# argsort = np.argsort(boxArea)
|
||||
# rand_idx = random.randint(min(len(argsort),6))
|
||||
# print('rand idx',rand_idx)
|
||||
rand_idx = np.random.randint(len(boxArea))
|
||||
rand_Side = boxArea[rand_idx] ** 0.5
|
||||
# rand_Side = min(boxes[rand_idx,2] - boxes[rand_idx,0] + 1,
|
||||
# boxes[rand_idx,3] - boxes[rand_idx,1] + 1)
|
||||
|
||||
distance = infDistance
|
||||
anchor_idx = 5
|
||||
for i, anchor in enumerate(scale_array):
|
||||
if abs(anchor - rand_Side) < distance:
|
||||
distance = abs(anchor - rand_Side)
|
||||
anchor_idx = i
|
||||
|
||||
target_anchor = random.choice(scale_array[0:min(anchor_idx + 1, 5) + 1])
|
||||
ratio = float(target_anchor) / rand_Side
|
||||
ratio = ratio * (2**random.uniform(-1, 1))
|
||||
|
||||
if int(img_height * ratio * img_width * ratio) > maxSize * maxSize:
|
||||
ratio = (maxSize * maxSize / (img_height * img_width))**0.5
|
||||
|
||||
interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC,
|
||||
cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
|
||||
interp_method = random.choice(interp_methods)
|
||||
image = cv2.resize(img, None, None, fx=ratio,
|
||||
fy=ratio, interpolation=interp_method)
|
||||
|
||||
boxes[:, 0] *= ratio
|
||||
boxes[:, 1] *= ratio
|
||||
boxes[:, 2] *= ratio
|
||||
boxes[:, 3] *= ratio
|
||||
|
||||
height, width, _ = image.shape
|
||||
|
||||
sample_boxes = []
|
||||
|
||||
xmin = boxes[rand_idx, 0]
|
||||
ymin = boxes[rand_idx, 1]
|
||||
bw = (boxes[rand_idx, 2] - boxes[rand_idx, 0] + 1)
|
||||
bh = (boxes[rand_idx, 3] - boxes[rand_idx, 1] + 1)
|
||||
|
||||
w = h = cfg.INPUT_SIZE
|
||||
|
||||
for _ in range(50):
|
||||
if w < max(height, width):
|
||||
if bw <= w:
|
||||
w_off = random.uniform(xmin + bw - w, xmin)
|
||||
else:
|
||||
w_off = random.uniform(xmin, xmin + bw - w)
|
||||
|
||||
if bh <= h:
|
||||
h_off = random.uniform(ymin + bh - h, ymin)
|
||||
else:
|
||||
h_off = random.uniform(ymin, ymin + bh - h)
|
||||
else:
|
||||
w_off = random.uniform(width - w, 0)
|
||||
h_off = random.uniform(height - h, 0)
|
||||
|
||||
w_off = math.floor(w_off)
|
||||
h_off = math.floor(h_off)
|
||||
|
||||
# convert to integer rect x1,y1,x2,y2
|
||||
rect = np.array(
|
||||
[int(w_off), int(h_off), int(w_off + w), int(h_off + h)])
|
||||
|
||||
# keep overlap with gt box IF center in sampled patch
|
||||
centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
|
||||
# mask of gt boxes whose top-left corner lies inside the sampled rect
|
||||
m1 = (rect[0] <= boxes[:, 0]) * (rect[1] <= boxes[:, 1])
|
||||
# mask of gt boxes whose bottom-right corner lies inside the sampled rect
|
||||
m2 = (rect[2] >= boxes[:, 2]) * (rect[3] >= boxes[:, 3])
|
||||
# gt boxes satisfying both masks are fully contained in the rect
|
||||
mask = m1 * m2
|
||||
|
||||
overlap = jaccard_numpy(boxes, rect)
|
||||
# have any valid boxes? try again if not
|
||||
if not mask.any() and not overlap.max() > 0.7:
|
||||
continue
|
||||
else:
|
||||
sample_boxes.append(rect)
|
||||
|
||||
sampled_labels = []
|
||||
|
||||
if len(sample_boxes) > 0:
|
||||
choice_idx = np.random.randint(len(sample_boxes))
|
||||
choice_box = sample_boxes[choice_idx]
|
||||
# print('crop the box :',choice_box)
|
||||
centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
|
||||
m1 = (choice_box[0] < centers[:, 0]) * \
|
||||
(choice_box[1] < centers[:, 1])
|
||||
m2 = (choice_box[2] > centers[:, 0]) * \
|
||||
(choice_box[3] > centers[:, 1])
|
||||
mask = m1 * m2
|
||||
current_boxes = boxes[mask, :].copy()
|
||||
current_labels = labels[mask]
|
||||
current_boxes[:, :2] -= choice_box[:2]
|
||||
current_boxes[:, 2:] -= choice_box[:2]
|
||||
|
||||
if choice_box[0] < 0 or choice_box[1] < 0:
|
||||
new_img_width = width if choice_box[
|
||||
0] >= 0 else width - choice_box[0]
|
||||
new_img_height = height if choice_box[
|
||||
1] >= 0 else height - choice_box[1]
|
||||
image_pad = np.zeros(
|
||||
(new_img_height, new_img_width, 3), dtype=float)
|
||||
image_pad[:, :, :] = mean
|
||||
start_left = 0 if choice_box[0] >= 0 else -choice_box[0]
|
||||
start_top = 0 if choice_box[1] >= 0 else -choice_box[1]
|
||||
image_pad[start_top:, start_left:, :] = image
|
||||
|
||||
choice_box_w = choice_box[2] - choice_box[0]
|
||||
choice_box_h = choice_box[3] - choice_box[1]
|
||||
|
||||
start_left = choice_box[0] if choice_box[0] >= 0 else 0
|
||||
start_top = choice_box[1] if choice_box[1] >= 0 else 0
|
||||
end_right = start_left + choice_box_w
|
||||
end_bottom = start_top + choice_box_h
|
||||
current_image = image_pad[
|
||||
start_top:end_bottom, start_left:end_right, :].copy()
|
||||
image_height, image_width, _ = current_image.shape
|
||||
if cfg.filter_min_face:
|
||||
bbox_w = current_boxes[:, 2] - current_boxes[:, 0]
|
||||
bbox_h = current_boxes[:, 3] - current_boxes[:, 1]
|
||||
bbox_area = bbox_w * bbox_h
|
||||
mask = bbox_area > (cfg.min_face_size * cfg.min_face_size)
|
||||
current_boxes = current_boxes[mask]
|
||||
current_labels = current_labels[mask]
|
||||
for i in range(len(current_boxes)):
|
||||
sample_label = []
|
||||
sample_label.append(current_labels[i])
|
||||
sample_label.append(current_boxes[i][0] / image_width)
|
||||
sample_label.append(current_boxes[i][1] / image_height)
|
||||
sample_label.append(current_boxes[i][2] / image_width)
|
||||
sample_label.append(current_boxes[i][3] / image_height)
|
||||
sampled_labels += [sample_label]
|
||||
sampled_labels = np.array(sampled_labels)
|
||||
else:
|
||||
current_boxes /= np.array([image_width,
|
||||
image_height, image_width, image_height])
|
||||
sampled_labels = np.hstack(
|
||||
(current_labels[:, np.newaxis], current_boxes))
|
||||
|
||||
return current_image, sampled_labels
|
||||
|
||||
current_image = image[choice_box[1]:choice_box[
|
||||
3], choice_box[0]:choice_box[2], :].copy()
|
||||
image_height, image_width, _ = current_image.shape
|
||||
|
||||
if cfg.filter_min_face:
|
||||
bbox_w = current_boxes[:, 2] - current_boxes[:, 0]
|
||||
bbox_h = current_boxes[:, 3] - current_boxes[:, 1]
|
||||
bbox_area = bbox_w * bbox_h
|
||||
mask = bbox_area > (cfg.min_face_size * cfg.min_face_size)
|
||||
current_boxes = current_boxes[mask]
|
||||
current_labels = current_labels[mask]
|
||||
for i in range(len(current_boxes)):
|
||||
sample_label = []
|
||||
sample_label.append(current_labels[i])
|
||||
sample_label.append(current_boxes[i][0] / image_width)
|
||||
sample_label.append(current_boxes[i][1] / image_height)
|
||||
sample_label.append(current_boxes[i][2] / image_width)
|
||||
sample_label.append(current_boxes[i][3] / image_height)
|
||||
sampled_labels += [sample_label]
|
||||
sampled_labels = np.array(sampled_labels)
|
||||
else:
|
||||
current_boxes /= np.array([image_width,
|
||||
image_height, image_width, image_height])
|
||||
sampled_labels = np.hstack(
|
||||
(current_labels[:, np.newaxis], current_boxes))
|
||||
|
||||
return current_image, sampled_labels
|
||||
else:
|
||||
image_height, image_width, _ = image.shape
|
||||
if cfg.filter_min_face:
|
||||
bbox_w = boxes[:, 2] - boxes[:, 0]
|
||||
bbox_h = boxes[:, 3] - boxes[:, 1]
|
||||
bbox_area = bbox_w * bbox_h
|
||||
mask = bbox_area > (cfg.min_face_size * cfg.min_face_size)
|
||||
boxes = boxes[mask]
|
||||
labels = labels[mask]
|
||||
for i in range(len(boxes)):
|
||||
sample_label = []
|
||||
sample_label.append(labels[i])
|
||||
sample_label.append(boxes[i][0] / image_width)
|
||||
sample_label.append(boxes[i][1] / image_height)
|
||||
sample_label.append(boxes[i][2] / image_width)
|
||||
sample_label.append(boxes[i][3] / image_height)
|
||||
sampled_labels += [sample_label]
|
||||
sampled_labels = np.array(sampled_labels)
|
||||
else:
|
||||
boxes /= np.array([image_width, image_height,
|
||||
image_width, image_height])
|
||||
sampled_labels = np.hstack(
|
||||
(labels[:, np.newaxis], boxes))
|
||||
|
||||
return image, sampled_labels
|
||||
|
||||
|
||||
def preprocess(img, bbox_labels, mode, image_path):
|
||||
img_width, img_height = img.size
|
||||
sampled_labels = bbox_labels
|
||||
if mode == 'train':
|
||||
if cfg.apply_distort:
|
||||
img = distort_image(img)
|
||||
if cfg.apply_expand:
|
||||
img, bbox_labels, img_width, img_height = expand_image(
|
||||
img, bbox_labels, img_width, img_height)
|
||||
|
||||
batch_sampler = []
|
||||
prob = np.random.uniform(0., 1.)
|
||||
if prob > cfg.data_anchor_sampling_prob and cfg.anchor_sampling:
|
||||
scale_array = np.array(cfg.ANCHOR_SIZES)  # e.g. [16, 32, 64, 128, 256, 512]
|
||||
'''
|
||||
batch_sampler.append(
|
||||
sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.6, 0.0, True))
|
||||
sampled_bbox = generate_batch_random_samples(
|
||||
batch_sampler, bbox_labels, img_width, img_height, scale_array,
|
||||
cfg.resize_width, cfg.resize_height)
|
||||
'''
|
||||
img = np.array(img)
|
||||
img, sampled_labels = anchor_crop_image_sampling(
|
||||
img, bbox_labels, scale_array, img_width, img_height)
|
||||
'''
|
||||
if len(sampled_bbox) > 0:
|
||||
idx = int(np.random.uniform(0, len(sampled_bbox)))
|
||||
img, sampled_labels = crop_image_sampling(
|
||||
img, bbox_labels, sampled_bbox[idx], img_width, img_height,
|
||||
cfg.resize_width, cfg.resize_height, cfg.min_face_size)
|
||||
'''
|
||||
img = img.astype('uint8')
|
||||
img = Image.fromarray(img)
|
||||
else:
|
||||
batch_sampler.append(sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
|
||||
0.0, True))
|
||||
batch_sampler.append(sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
|
||||
0.0, True))
|
||||
batch_sampler.append(sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
|
||||
0.0, True))
|
||||
batch_sampler.append(sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
|
||||
0.0, True))
|
||||
batch_sampler.append(sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
|
||||
0.0, True))
|
||||
sampled_bbox = generate_batch_samples(
|
||||
batch_sampler, bbox_labels, img_width, img_height)
|
||||
|
||||
img = np.array(img)
|
||||
if len(sampled_bbox) > 0:
|
||||
idx = int(np.random.uniform(0, len(sampled_bbox)))
|
||||
img, sampled_labels = crop_image(
|
||||
img, bbox_labels, sampled_bbox[idx], img_width, img_height,
|
||||
cfg.resize_width, cfg.resize_height, cfg.min_face_size)
|
||||
|
||||
img = Image.fromarray(img)
|
||||
|
||||
interp_mode = [
|
||||
Image.BILINEAR, Image.HAMMING, Image.NEAREST, Image.BICUBIC,
|
||||
Image.LANCZOS
|
||||
]
|
||||
interp_indx = np.random.randint(0, 5)
|
||||
|
||||
img = img.resize((cfg.resize_width, cfg.resize_height),
|
||||
resample=interp_mode[interp_indx])
|
||||
|
||||
img = np.array(img)
|
||||
|
||||
if mode == 'train':
|
||||
mirror = int(np.random.uniform(0, 2))
|
||||
if mirror == 1:
|
||||
img = img[:, ::-1, :]
|
||||
for i in six.moves.xrange(len(sampled_labels)):
|
||||
tmp = sampled_labels[i][1]
|
||||
sampled_labels[i][1] = 1 - sampled_labels[i][3]
|
||||
sampled_labels[i][3] = 1 - tmp
|
||||
|
||||
#img = Image.fromarray(img)
|
||||
img = to_chw_bgr(img)
|
||||
img = img.astype('float32')
|
||||
img -= cfg.img_mean
|
||||
img = img[[2, 1, 0], :, :] # to RGB
|
||||
#img = img * cfg.scale
|
||||
|
||||
return img, sampled_labels
|
|
@ -0,0 +1,272 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.data as data
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision.transforms as transforms
|
||||
import os.path as osp
|
||||
|
||||
import cv2
|
||||
import time
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import scipy.io as sio
|
||||
|
||||
from data.choose_config import cfg
|
||||
cfg = cfg.cfg
|
||||
|
||||
from torch.autograd import Variable
|
||||
from utils.augmentations import to_chw_bgr
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
parser = argparse.ArgumentParser(description='s3fd evaluation wider')
|
||||
parser.add_argument('--model', type=str,
|
||||
default='./weights/rpool_face_c.pth', help='trained model')
|
||||
parser.add_argument('--thresh', default=0.05, type=float,
|
||||
help='Final confidence threshold')
|
||||
parser.add_argument('--model_arch',
|
||||
default='RPool_Face_C', type=str,
|
||||
choices=['RPool_Face_C', 'RPool_Face_Quant', 'RPool_Face_QVGA_monochrome'],
|
||||
help='choose architecture among rpool variants')
|
||||
parser.add_argument('--save_folder', type=str,
|
||||
default='rpool_face_predictions', help='folder for saving predictions')
|
||||
parser.add_argument('--subset', type=str,
|
||||
default='val',
|
||||
choices=['val', 'test'],
|
||||
help='choose which set to run testing on')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
use_cuda = torch.cuda.is_available()
|
||||
|
||||
if use_cuda:
|
||||
torch.set_default_tensor_type('torch.cuda.FloatTensor')
|
||||
else:
|
||||
torch.set_default_tensor_type('torch.FloatTensor')
|
||||
|
||||
|
||||
def detect_face(net, img, shrink):
|
||||
if shrink != 1:
|
||||
img = cv2.resize(img, None, None, fx=shrink, fy=shrink,
|
||||
interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
x = to_chw_bgr(img)
|
||||
x = x.astype('float32')
|
||||
x -= cfg.img_mean
|
||||
x = x[[2, 1, 0], :, :]
|
||||
|
||||
if cfg.IS_MONOCHROME:
|
||||
x = 0.299 * x[0] + 0.587 * x[1] + 0.114 * x[2]
|
||||
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)
|
||||
else:
|
||||
x = torch.from_numpy(x).unsqueeze(0)
|
||||
|
||||
if use_cuda:
|
||||
x = x.cuda()
|
||||
|
||||
y = net(x)
|
||||
detections = y.data
|
||||
detections = detections.cpu().numpy()
|
||||
|
||||
det_conf = detections[0, 1, :, 0]
|
||||
det_xmin = img.shape[1] * detections[0, 1, :, 1] / shrink
|
||||
det_ymin = img.shape[0] * detections[0, 1, :, 2] / shrink
|
||||
det_xmax = img.shape[1] * detections[0, 1, :, 3] / shrink
|
||||
det_ymax = img.shape[0] * detections[0, 1, :, 4] / shrink
|
||||
det = np.column_stack((det_xmin, det_ymin, det_xmax, det_ymax, det_conf))
|
||||
|
||||
keep_index = np.where(det[:, 4] >= args.thresh)[0]
|
||||
det = det[keep_index, :]
|
||||
|
||||
return det
|
||||
|
||||
|
||||
def multi_scale_test(net, image, max_im_shrink):
|
||||
# shrink the image and detect only big faces at this scale
|
||||
st = 0.5 if max_im_shrink >= 0.75 else 0.5 * max_im_shrink
|
||||
det_s = detect_face(net, image, st)
|
||||
index = np.where(np.maximum(
|
||||
det_s[:, 2] - det_s[:, 0] + 1, det_s[:, 3] - det_s[:, 1] + 1) > 30)[0]
|
||||
det_s = det_s[index, :]
|
||||
|
||||
# enlarge the image once
|
||||
bt = min(2, max_im_shrink) if max_im_shrink > 1 else (
|
||||
st + max_im_shrink) / 2
|
||||
det_b = detect_face(net, image, bt)
|
||||
|
||||
# enlarge small images several more times to detect small faces
|
||||
if max_im_shrink > 2:
|
||||
bt *= 2
|
||||
while bt < max_im_shrink:
|
||||
det_b = np.row_stack((det_b, detect_face(net, image, bt)))
|
||||
bt *= 2
|
||||
det_b = np.row_stack((det_b, detect_face(net, image, max_im_shrink)))
|
||||
|
||||
# at enlarged scales, keep only small-face detections
|
||||
if bt > 1:
|
||||
index = np.where(np.minimum(
|
||||
det_b[:, 2] - det_b[:, 0] + 1, det_b[:, 3] - det_b[:, 1] + 1) < 100)[0]
|
||||
det_b = det_b[index, :]
|
||||
else:
|
||||
index = np.where(np.maximum(
|
||||
det_b[:, 2] - det_b[:, 0] + 1, det_b[:, 3] - det_b[:, 1] + 1) > 30)[0]
|
||||
det_b = det_b[index, :]
|
||||
|
||||
return det_s, det_b
|
||||
|
||||
|
||||
def flip_test(net, image, shrink):
|
||||
image_f = cv2.flip(image, 1)
|
||||
det_f = detect_face(net, image_f, shrink)
|
||||
|
||||
det_t = np.zeros(det_f.shape)
|
||||
det_t[:, 0] = image.shape[1] - det_f[:, 2]
|
||||
det_t[:, 1] = det_f[:, 1]
|
||||
det_t[:, 2] = image.shape[1] - det_f[:, 0]
|
||||
det_t[:, 3] = det_f[:, 3]
|
||||
det_t[:, 4] = det_f[:, 4]
|
||||
return det_t
|
||||
|
||||
|
||||
def bbox_vote(det):
|
||||
order = det[:, 4].ravel().argsort()[::-1]
|
||||
det = det[order, :]
|
||||
while det.shape[0] > 0:
|
||||
# IOU
|
||||
area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1)
|
||||
xx1 = np.maximum(det[0, 0], det[:, 0])
|
||||
yy1 = np.maximum(det[0, 1], det[:, 1])
|
||||
xx2 = np.minimum(det[0, 2], det[:, 2])
|
||||
yy2 = np.minimum(det[0, 3], det[:, 3])
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
o = inter / (area[0] + area[:] - inter)
|
||||
|
||||
# collect the detections to merge and remove them from the remaining set
|
||||
merge_index = np.where(o >= 0.3)[0]
|
||||
det_accu = det[merge_index, :]
|
||||
det = np.delete(det, merge_index, 0)
|
||||
|
||||
if merge_index.shape[0] <= 1:
|
||||
continue
|
||||
det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4))
|
||||
max_score = np.max(det_accu[:, 4])
|
||||
det_accu_sum = np.zeros((1, 5))
|
||||
det_accu_sum[:, 0:4] = np.sum(
|
||||
det_accu[:, 0:4], axis=0) / np.sum(det_accu[:, -1:])
|
||||
det_accu_sum[:, 4] = max_score
|
||||
try:
|
||||
dets = np.row_stack((dets, det_accu_sum))
|
||||
except:
|
||||
dets = det_accu_sum
|
||||
|
||||
dets = dets[0:750, :]
|
||||
return dets
|
||||
|
||||
|
||||
def get_data():
|
||||
subset = args.subset
|
||||
|
||||
WIDER_ROOT = os.path.join(cfg.HOME, 'WIDER_FACE')
|
||||
if subset == 'val':
|
||||
wider_face = sio.loadmat(
|
||||
os.path.join(WIDER_ROOT, 'wider_face_split',
|
||||
'wider_face_val.mat'))
|
||||
else:
|
||||
wider_face = sio.loadmat(
|
||||
os.path.join(WIDER_ROOT, 'wider_face_split',
|
||||
'wider_face_test.mat'))
|
||||
event_list = wider_face['event_list']
|
||||
file_list = wider_face['file_list']
|
||||
del wider_face
|
||||
|
||||
imgs_path = os.path.join(
|
||||
cfg.FACE.WIDER_DIR, 'WIDER_{}'.format(subset), 'images')
|
||||
save_path = './{}'.format(args.save_folder)
|
||||
|
||||
return event_list, file_list, imgs_path, save_path
|
||||
|
||||
if __name__ == '__main__':
|
||||
event_list, file_list, imgs_path, save_path = get_data()
|
||||
cfg.USE_NMS = False
|
||||
|
||||
module = import_module('models.' + args.model_arch)
|
||||
net = module.build_s3fd('test', cfg.NUM_CLASSES)
|
||||
|
||||
net = torch.nn.DataParallel(net)
|
||||
|
||||
|
||||
checkpoint_dict = torch.load(args.model)
|
||||
|
||||
model_dict = net.state_dict()
|
||||
|
||||
|
||||
model_dict.update(checkpoint_dict)
|
||||
net.load_state_dict(model_dict)
|
||||
|
||||
|
||||
net.eval()
|
||||
|
||||
|
||||
if use_cuda:
|
||||
net.cuda()
|
||||
cudnn.benchmark = True
|
||||
|
||||
|
||||
counter = 0
|
||||
|
||||
for index, event in enumerate(event_list):
|
||||
filelist = file_list[index][0]
|
||||
path = os.path.join(save_path, str(event[0][0]))#.encode('utf-8'))
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
for num, file in enumerate(filelist):
|
||||
im_name = str(file[0][0])#.encode('utf-8')
|
||||
in_file = os.path.join(imgs_path, event[0][0], im_name[:] + '.jpg')
|
||||
img = Image.open(in_file)
|
||||
if img.mode == 'L':
|
||||
img = img.convert('RGB')
|
||||
img = np.array(img)
|
||||
|
||||
|
||||
max_im_shrink = np.sqrt(
|
||||
1700 * 1200 / (img.shape[0] * img.shape[1]))
|
||||
|
||||
shrink = max_im_shrink if max_im_shrink < 1 else 1
|
||||
counter += 1
|
||||
|
||||
t1 = time.time()
|
||||
det0 = detect_face(net, img, shrink)
|
||||
|
||||
det1 = flip_test(net, img, shrink) # flip test
|
||||
[det2, det3] = multi_scale_test(net, img, max_im_shrink)
|
||||
|
||||
det = np.row_stack((det0, det1, det2, det3))
|
||||
dets = bbox_vote(det)
|
||||
|
||||
t2 = time.time()
|
||||
print('Detect %04d th image costs %.4f' % (counter, t2 - t1))
|
||||
|
||||
fout = open(osp.join(save_path, str(event[0][
|
||||
0]), im_name + '.txt'), 'w')
|
||||
fout.write('{:s}\n'.format(str(event[0][0]) + '/' + im_name + '.jpg'))
|
||||
fout.write('{:d}\n'.format(dets.shape[0]))
|
||||
for i in range(dets.shape[0]):
|
||||
xmin = dets[i][0]
|
||||
ymin = dets[i][1]
|
||||
xmax = dets[i][2]
|
||||
ymax = dets[i][3]
|
||||
score = dets[i][4]
|
||||
fout.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.
|
||||
format(xmin, ymin, (xmax - xmin + 1), (ymax - ymin + 1), score))
|
|
@ -0,0 +1,71 @@
|
|||
# Code for Visual Wake Words experiments with RNNPool
|
||||
|
||||
The Visual Wake Word challenge is a binary classification problem of detecting whether a person is present in
|
||||
an image or not, as introduced by [Chowdhery et al.](https://arxiv.org/abs/1906.05721).
|
||||
|
||||
## Dataset
|
||||
The Visual Wake Words Dataset is derived from the publicly available [COCO](cocodataset.org/#/home) dataset. The Visual Wake Words Challenge evaluates accuracy on the [minival image ids](https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_minival_ids.txt),
|
||||
and uses the remaining 115k images of the COCO training/validation dataset for training. The process of creating the Visual Wake Words dataset from the COCO dataset is as follows.
|
||||
Each image is assigned a label 1 or 0.
|
||||
The label 1 is assigned as long as the image has at least one bounding box corresponding
|
||||
to the object of interest (e.g. person) with the box area greater than a certain threshold
|
||||
(e.g. 0.5% of the image area).
|
||||
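As a rough illustration of this labeling rule, here is a minimal sketch using `pycocotools`; the helper `vww_label` is hypothetical and is shown only for clarity — the dataset itself is generated by `scripts/create_visualwakewords_annotations.py` as described below.

```python
# Hypothetical sketch of the labeling rule above; the real dataset is built by
# scripts/create_visualwakewords_annotations.py.
from pycocotools.coco import COCO

def vww_label(coco, img_id, person_cat_id, threshold=0.005):
    """Return 1 if the image has at least one person box covering more than
    `threshold` of the image area, else 0."""
    img_info = coco.loadImgs(img_id)[0]
    img_area = img_info['width'] * img_info['height']
    ann_ids = coco.getAnnIds(imgIds=img_id, catIds=[person_cat_id], iscrowd=None)
    for ann in coco.loadAnns(ann_ids):
        box_w, box_h = ann['bbox'][2], ann['bbox'][3]
        if box_w * box_h > threshold * img_area:
            return 1
    return 0

# Example usage (paths and ids are placeholders):
# coco = COCO('path-to-mscoco-dataset/annotations/instances_maxitrain.json')
# person_id = coco.getCatIds(catNms=['person'])[0]
# label = vww_label(coco, img_id=coco.getImgIds()[0], person_cat_id=person_id)
```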
|
||||
To download the COCO dataset, use the script `scripts/download_mscoco.sh`:
|
||||
```bash
|
||||
bash scripts/download_mscoco.sh path-to-mscoco-dataset
|
||||
```
|
||||
|
||||
To create the COCO annotation files for the maxitrain/minival split, use:
|
||||
`scripts/create_coco_train_minival_split.py`
|
||||
|
||||
```bash
|
||||
TRAIN_ANNOTATIONS_FILE="path-to-mscoco-dataset/annotations/instances_train2014.json"
|
||||
VAL_ANNOTATIONS_FILE="path-to-mscoco-dataset/annotations/instances_val2014.json"
|
||||
DIR="path-to-mscoco-dataset/annotations/"
|
||||
python scripts/create_coco_train_minival_split.py \
|
||||
--train_annotations_file="${TRAIN_ANNOTATIONS_FILE}" \
|
||||
--val_annotations_file="${VAL_ANNOTATIONS_FILE}" \
|
||||
--output_dir="${DIR}"
|
||||
```
|
||||
|
||||
|
||||
To generate the new annotations, use the script `scripts/create_visualwakewords_annotations.py`.
|
||||
```bash
|
||||
MAXITRAIN_ANNOTATIONS_FILE="path-to-mscoco-dataset/annotations/instances_maxitrain.json"
|
||||
MINIVAL_ANNOTATIONS_FILE="path-to-mscoco-dataset/annotations/instances_minival.json"
|
||||
VWW_OUTPUT_DIR="new-path-to-visualwakewords-dataset/annotations/"
|
||||
python scripts/create_visualwakewords_annotations.py \
|
||||
--train_annotations_file="${MAXITRAIN_ANNOTATIONS_FILE}" \
|
||||
--val_annotations_file="${MINIVAL_ANNOTATIONS_FILE}" \
|
||||
--output_dir="${VWW_OUTPUT_DIR}" \
|
||||
--threshold=0.005 \
|
||||
--foreground_class='person'
|
||||
```
|
||||
|
||||
|
||||
## Training
|
||||
|
||||
```bash
|
||||
python train_visualwakewords.py \
|
||||
--model_arch model_mobilenet_rnnpool \
|
||||
--lr 0.05 \
|
||||
--epochs 900 \
|
||||
--data "path-to-mscoco-dataset" \
|
||||
--ann "new-path-to-visualwakewords-dataset"
|
||||
```
|
||||
In `--data` and `--ann`, pass the paths used in the dataset-creation steps above for the MS COCO images and the Visual Wake Words annotations respectively (see the loading sketch below). This script should reach a validation accuracy of about 89.57% upon completion.
|
||||
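For reference, a minimal sketch of how the generated annotations can be loaded with the `pyvww` package listed in `requirements.txt`; the actual loader in `train_visualwakewords.py` may differ, and the paths and annotation file names below are placeholders.

```python
# Minimal loading sketch, assuming the pyvww package and the annotation files
# generated in the steps above; paths and file names are placeholders.
import torch
import torchvision.transforms as transforms
from pyvww.pytorch import VisualWakeWordsClassification

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

train_set = VisualWakeWordsClassification(
    root="path-to-mscoco-dataset/all2014",  # --data: directory with COCO images
    annFile="new-path-to-visualwakewords-dataset/annotations/instances_train.json",  # --ann
    transform=transform_train)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=64, shuffle=True, num_workers=4)
# Each batch yields (images, labels) with labels in {0, 1} (1 = person present).
```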
|
||||
## Evaluation
|
||||
|
||||
```bash
|
||||
python eval.py \
|
||||
--weights vww_rnnpool.pth \
|
||||
--model_arch model_mobilenet_rnnpool \
|
||||
--image_folder images \
|
||||
```
|
||||
|
||||
The `--weights` argument is the saved checkpoint of a model trained with the architecture passed in `--model_arch`. The folder containing the images to be evaluated is passed via `--image_folder`; note that the script only picks up `.bmp` files from this folder. For each such image it prints either 'Person present' or 'No person present'.
|
||||
|
||||
|
||||
Dataset creation code is from https://github.com/Mxbonn/visualwakewords/
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from importlib import import_module
|
||||
import skimage
|
||||
from skimage import filters
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
|
||||
#Arg parser
|
||||
parser = argparse.ArgumentParser(description='PyTorch VisualWakeWords evaluation')
|
||||
parser.add_argument('--weights', default=None, type=str, help='load from checkpoint')
|
||||
parser.add_argument('--model_arch',
|
||||
default='model_mobilenet_rnnpool', type=str,
|
||||
choices=['model_mobilenet_rnnpool', 'model_mobilenet_2rnnpool'],
|
||||
help='choose architecture among rpool variants')
|
||||
parser.add_argument('--image_folder', default=None, type=str, help='folder containing images')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
|
||||
transform_test = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
normalize
|
||||
])
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
module = import_module(args.model_arch)
|
||||
model = module.mobilenetv2_rnnpool(num_classes=2, width_mult=0.35, last_channel=320)
|
||||
model = model.to(device)
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
|
||||
|
||||
checkpoint = torch.load(args.weights)
|
||||
checkpoint_dict = checkpoint['model']
|
||||
model_dict = model.state_dict()
|
||||
model_dict.update(checkpoint_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
|
||||
model.eval()
|
||||
img_path = args.image_folder
|
||||
img_list = [os.path.join(img_path, x)
|
||||
for x in os.listdir(img_path) if x.endswith('bmp')]
|
||||
|
||||
for path in sorted(img_list):
|
||||
img = Image.open(path).convert('RGB')
|
||||
img = transform_test(img)
|
||||
img = img.to(device)
|
||||
img = img.unsqueeze(0)
|
||||
|
||||
out = model(img)
|
||||
|
||||
print(path)
|
||||
print(out)
|
||||
if out[0][0]>0.15:
|
||||
print('No person present')
|
||||
else:
|
||||
print('Person present')
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,208 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import re
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import torch.utils.checkpoint as cp
|
||||
from collections import OrderedDict
|
||||
from torchvision.models.utils import load_state_dict_from_url
|
||||
import sys; sys.path.append('..')
|
||||
from rnnpool import *
|
||||
|
||||
__all__ = ['MobileNetV2', 'mobilenetv2_rnnpool']
|
||||
|
||||
|
||||
model_urls = {
|
||||
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
|
||||
}
|
||||
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
||||
padding = (kernel_size - 1) // 2
|
||||
super(ConvBNReLU, self).__init__(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
||||
nn.BatchNorm2d(out_planes, momentum=0.01),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup, momentum=0.01),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(self,
|
||||
num_classes=1000,
|
||||
width_mult=0.5,
|
||||
inverted_residual_setting=None,
|
||||
round_nearest=8,
|
||||
block=None,
|
||||
last_channel = 1280):
|
||||
"""
|
||||
MobileNet V2 main class
|
||||
Args:
|
||||
num_classes (int): Number of classes
|
||||
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||
inverted_residual_setting: Network structure
|
||||
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||
Set to 1 to turn off rounding
|
||||
block: Module specifying inverted residual building block for mobilenet
|
||||
"""
|
||||
super(MobileNetV2, self).__init__()
|
||||
|
||||
if block is None:
|
||||
block = InvertedResidual
|
||||
input_channel = 8
|
||||
#last_channel = 1280
|
||||
|
||||
if inverted_residual_setting is None:
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
|
||||
|
||||
# only check the first element, assuming user knows t,c,n,s are required
|
||||
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
||||
raise ValueError("inverted_residual_setting should be non-empty "
|
||||
"or a 4-element list, got {}".format(inverted_residual_setting))
|
||||
|
||||
# building first layer
|
||||
input_channel = _make_divisible(input_channel, round_nearest)
|
||||
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
||||
self.features_init = ConvBNReLU(3, input_channel, stride=2)
|
||||
|
||||
self.unfold = nn.Unfold(kernel_size=(6,6),stride=(4,4))
|
||||
|
||||
self.rnn_model = RNNPool(6, 6, 8, 8, input_channel)#num_init_features)
|
||||
self.fold = nn.Fold(kernel_size=(1,1),output_size=(27,27))
|
||||
|
||||
self.rnn_model_end = RNNPool(7, 7, int(self.last_channel/4), int(self.last_channel/4), self.last_channel)
|
||||
|
||||
features=[]
|
||||
|
||||
input_channel = 32
|
||||
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
||||
input_channel = output_channel
|
||||
# building last several layers
|
||||
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
|
||||
# make it nn.Sequential
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
# building classifier
|
||||
self.classifier = nn.Sequential(
|
||||
#nn.Dropout(0.2),
|
||||
nn.Linear(self.last_channel, num_classes),
|
||||
)
|
||||
|
||||
# weight initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size = x.shape[0]
|
||||
|
||||
x = self.features_init(x)
|
||||
|
||||
patches = self.unfold(x)
|
||||
patches = torch.cat(torch.unbind(patches,dim=2),dim=0)
|
||||
patches = torch.reshape(patches,(-1,8,6,6))
|
||||
|
||||
|
||||
output_x = int((x.shape[2]-6)/4 + 1)
|
||||
output_y = int((x.shape[3]-6)/4 + 1)
|
||||
|
||||
rnnX = self.rnn_model(patches, int(batch_size)*output_x*output_y)
|
||||
|
||||
x = torch.stack(torch.split(rnnX, split_size_or_sections=int(batch_size), dim=0),dim=2)
|
||||
|
||||
x = self.fold(x)
|
||||
|
||||
x = F.pad(x, (0,1,0,1), mode='replicate')
|
||||
|
||||
x = self.features(x)
|
||||
x = self.rnn_model_end(x, batch_size)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
def mobilenetv2_rnnpool(pretrained=False, progress=True, **kwargs):
|
||||
"""
|
||||
Constructs a MobileNetV2 architecture from
|
||||
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
model = MobileNetV2(**kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
|
||||
progress=progress)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
|
@ -0,0 +1,206 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import re
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import torch.utils.checkpoint as cp
|
||||
from collections import OrderedDict
|
||||
from torchvision.models.utils import load_state_dict_from_url
|
||||
from edgeml_pytorch.graph.rnnpool import *
|
||||
|
||||
__all__ = ['MobileNetV2', 'mobilenetv2_rnnpool']
|
||||
|
||||
|
||||
model_urls = {
|
||||
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
|
||||
}
|
||||
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
||||
padding = (kernel_size - 1) // 2
|
||||
super(ConvBNReLU, self).__init__(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
||||
nn.BatchNorm2d(out_planes, momentum=0.01),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup, momentum=0.01),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(self,
|
||||
num_classes=1000,
|
||||
width_mult=0.5,
|
||||
inverted_residual_setting=None,
|
||||
round_nearest=8,
|
||||
block=None,
|
||||
last_channel = 1280):
|
||||
"""
|
||||
MobileNet V2 main class
|
||||
Args:
|
||||
num_classes (int): Number of classes
|
||||
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||
inverted_residual_setting: Network structure
|
||||
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||
Set to 1 to turn off rounding
|
||||
block: Module specifying inverted residual building block for mobilenet
|
||||
"""
|
||||
super(MobileNetV2, self).__init__()
|
||||
|
||||
if block is None:
|
||||
block = InvertedResidual
|
||||
input_channel = 8
|
||||
#last_channel = 1280
|
||||
|
||||
if inverted_residual_setting is None:
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
|
||||
|
||||
# only check the first element, assuming user knows t,c,n,s are required
|
||||
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
||||
raise ValueError("inverted_residual_setting should be non-empty "
|
||||
"or a 4-element list, got {}".format(inverted_residual_setting))
|
||||
|
||||
# building first layer
|
||||
input_channel = _make_divisible(input_channel, round_nearest)
|
||||
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
||||
self.features_init = ConvBNReLU(3, input_channel, stride=2)
|
||||
|
||||
self.unfold = nn.Unfold(kernel_size=(6,6),stride=(4,4))
|
||||
|
||||
self.rnn_model = RNNPool(6, 6, 8, 8, input_channel)#num_init_features)
|
||||
self.fold = nn.Fold(kernel_size=(1,1),output_size=(27,27))
|
||||
|
||||
|
||||
features=[]
|
||||
|
||||
input_channel = 32
|
||||
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
||||
input_channel = output_channel
|
||||
# building last several layers
|
||||
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
# building classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(self.last_channel, num_classes),
|
||||
)
|
||||
|
||||
# weight initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size = x.shape[0]
|
||||
|
||||
|
||||
x = self.features_init(x)
|
||||
|
||||
patches = self.unfold(x)
|
||||
patches = torch.cat(torch.unbind(patches,dim=2),dim=0)
|
||||
patches = torch.reshape(patches,(-1,8,6,6))
|
||||
|
||||
|
||||
output_x = int((x.shape[2]-6)/4 + 1)
|
||||
output_y = int((x.shape[3]-6)/4 + 1)
|
||||
|
||||
rnnX = self.rnn_model(patches, int(batch_size)*output_x*output_y)
|
||||
|
||||
x = torch.stack(torch.split(rnnX, split_size_or_sections=int(batch_size), dim=0),dim=2)
|
||||
|
||||
x = self.fold(x)
|
||||
|
||||
x = F.pad(x, (0,1,0,1), mode='replicate')
|
||||
|
||||
x = self.features(x)
|
||||
x = x.mean([2, 3])
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
def mobilenetv2_rnnpool(pretrained=False, progress=True, **kwargs):
|
||||
"""
|
||||
Constructs a MobileNetV2 architecture from
|
||||
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
model = MobileNetV2(**kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
|
||||
progress=progress)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
|
@ -0,0 +1,9 @@
|
|||
pycocotools
|
||||
pyvww
|
||||
easydict==1.9
|
||||
importlib-metadata==1.5.0
|
||||
matplotlib==3.2.1
|
||||
opencv-python-headless==4.2.0.32
|
||||
scikit-image==0.15.0
|
||||
tensorboard==1.14.0
|
||||
tensorboardX==1.9
|
|
@ -0,0 +1,120 @@
|
|||
## Code from https://github.com/Mxbonn/visualwakewords
|
||||
|
||||
|
||||
"""Create maxitrain and minival annotations.
|
||||
This script generates a new train validation split with 115k training and 8k validation images.
|
||||
Based on the split used by Google
|
||||
(https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_minival_ids.txt).
|
||||
|
||||
Usage:
|
||||
From this folder, run the following commands: (2014 can be replaced by 2017 if you downloaded the 2017 dataset)
|
||||
TRAIN_ANNOTATIONS_FILE="path-to-mscoco-dataset/annotations/instances_train2014.json"
|
||||
VAL_ANNOTATIONS_FILE="path-to-mscoco-dataset/annotations/instances_val2014.json"
|
||||
OUTPUT_DIR="path-to-mscoco-dataset/annotations/"
|
||||
python create_coco_train_minival_split.py \
|
||||
--train_annotations_file="${TRAIN_ANNOTATIONS_FILE}" \
|
||||
--val_annotations_file="${VAL_ANNOTATIONS_FILE}" \
|
||||
--output_dir="${OUTPUT_DIR}"
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
|
||||
def create_maxitrain_minival(train_file, val_file, output_dir):
|
||||
""" Generate maxitrain and minival annotations files.
|
||||
Loads COCO 2014/2017 train and validation json files and creates a new split with
|
||||
115k training images and 8k validation images.
|
||||
Based on the split used by Google
|
||||
(https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_minival_ids.txt).
|
||||
Args:
|
||||
train_file: JSON file containing COCO 2014 or 2017 train annotations
|
||||
val_file: JSON file containing COCO 2014 or 2017 validation annotations
|
||||
output_dir: Directory where the new annotation files will be stored.
|
||||
"""
|
||||
maxitrain_path = os.path.join(
|
||||
output_dir, 'instances_maxitrain.json')
|
||||
minival_path = os.path.join(
|
||||
output_dir, 'instances_minival.json')
|
||||
train_json = json.load(open(train_file, 'r'))
|
||||
val_json = json.load(open(val_file, 'r'))
|
||||
|
||||
info = train_json['info']
|
||||
categories = train_json['categories']
|
||||
licenses = train_json['licenses']
|
||||
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
file_path = os.path.join(dir_path, 'mscoco_minival_ids.txt')
|
||||
minival_ids_f = open(file_path, 'r')
|
||||
minival_ids = minival_ids_f.readlines()
|
||||
minival_ids = [int(i) for i in minival_ids]
|
||||
|
||||
train_images = train_json['images']
|
||||
val_images = val_json['images']
|
||||
train_annotations = train_json['annotations']
|
||||
val_annotations = val_json['annotations']
|
||||
|
||||
maxitrain_images = []
|
||||
minival_images = []
|
||||
maxitrain_annotations = []
|
||||
minival_annotations = []
|
||||
|
||||
for _images in [train_images, val_images]:
|
||||
for img in _images:
|
||||
img_id = img['id']
|
||||
if img_id in minival_ids:
|
||||
minival_images.append(img)
|
||||
else:
|
||||
maxitrain_images.append(img)
|
||||
|
||||
for _annotations in [train_annotations, val_annotations]:
|
||||
for ann in _annotations:
|
||||
img_id = ann['image_id']
|
||||
if img_id in minival_ids:
|
||||
minival_annotations.append(ann)
|
||||
else:
|
||||
maxitrain_annotations.append(ann)
|
||||
|
||||
with open(maxitrain_path, 'w') as fp:
|
||||
json.dump(
|
||||
{
|
||||
"info": info,
|
||||
"licenses": licenses,
|
||||
'images': maxitrain_images,
|
||||
'annotations': maxitrain_annotations,
|
||||
'categories': categories,
|
||||
}, fp)
|
||||
|
||||
with open(minival_path, 'w') as fp:
|
||||
json.dump(
|
||||
{
|
||||
"info": info,
|
||||
"licenses": licenses,
|
||||
'images': minival_images,
|
||||
'annotations': minival_annotations,
|
||||
'categories': categories,
|
||||
}, fp)
|
||||
|
||||
|
||||
def main(args):
|
||||
output_dir = os.path.realpath(os.path.expanduser(args.output_dir))
|
||||
train_annotations_file = os.path.realpath(os.path.expanduser(args.train_annotations_file))
|
||||
val_annotations_file = os.path.realpath(os.path.expanduser(args.val_annotations_file))
|
||||
|
||||
if not os.path.isdir(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
create_maxitrain_minival(train_annotations_file, val_annotations_file, output_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = ArgumentParser(description="Script that takes the 2014/2017 training and validation annotations and"
|
||||
"creates a train split of 115k images and a minival of 8k.")
|
||||
parser.add_argument('--train_annotations_file', type=str, required=True,
|
||||
help='COCO2014/2017 Training annotations JSON file')
|
||||
parser.add_argument('--val_annotations_file', type=str, required=True,
|
||||
help='COCO2014/2017 Validation annotations JSON file')
|
||||
parser.add_argument('--output_dir', type=str, required=True,
|
||||
help='Output directory where the maxitrain and minival annotations files will be stored')
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,217 @@
|
|||
## Code from https://github.com/Mxbonn/visualwakewords
|
||||
|
||||
|
||||
"""Create Visual Wakewords annotations.
|
||||
This script generates the Visual WakeWords dataset annotations from the raw COCO dataset.
|
||||
The resulting annotations can then be used with `pyvww.utils.VisualWakeWords` and
|
||||
`pyvww.pytorch.VisualWakeWordsClassification`.
|
||||
|
||||
The Visual WakeWords dataset is derived from the COCO dataset to design tiny models
that classify two classes, such as person/not-person. The COCO annotations
are filtered to two classes: foreground_class and background
(e.g. person and not-person). Bounding boxes of small objects
with area less than 5% of the image area are filtered out.
The resulting annotations file follows the COCO data format.
|
||||
{
|
||||
"info" : info,
|
||||
"images" : [image],
|
||||
"annotations" : [annotation],
|
||||
"licenses" : [license],
|
||||
}
|
||||
|
||||
info{
|
||||
"year" : int,
|
||||
"version" : str,
|
||||
"description" : str,
|
||||
"url" : str,
|
||||
}
|
||||
|
||||
image{
|
||||
"id" : int,
|
||||
"width" : int,
|
||||
"height" : int,
|
||||
"file_name" : str,
|
||||
"license" : int,
|
||||
"flickr_url" : str,
|
||||
"coco_url" : str,
|
||||
"date_captured" : datetime,
|
||||
}
|
||||
|
||||
license{
|
||||
"id" : int,
|
||||
"name" : str,
|
||||
"url" : str,
|
||||
}
|
||||
|
||||
annotation{
|
||||
"id" : int,
|
||||
"image_id" : int,
|
||||
"category_id" : int,
|
||||
"area" : float,
|
||||
"bbox" : [x,y,width,height],
|
||||
"iscrowd" : 0 or 1,
|
||||
}
|
||||
|
||||
Example usage:
|
||||
From this folder, run the following commands:
|
||||
bash download_mscoco.sh path-to-mscoco-dataset
|
||||
TRAIN_ANNOTATIONS_FILE="path-to-mscoco-dataset/annotations/instances_train2014.json"
|
||||
VAL_ANNOTATIONS_FILE="path-to-mscoco-dataset/annotations/instances_val2014.json"
|
||||
DIR="path-to-mscoco-dataset/annotations/"
|
||||
python create_coco_train_minival_split.py \
|
||||
--train_annotations_file="${TRAIN_ANNOTATIONS_FILE}" \
|
||||
--val_annotations_file="${VAL_ANNOTATIONS_FILE}" \
|
||||
--output_dir="${DIR}"
|
||||
MAXITRAIN_ANNOTATIONS_FILE="path-to-mscoco-dataset/annotations/instances_maxitrain.json"
|
||||
MINIVAL_ANNOTATIONS_FILE="path-to-mscoco-dataset/annotations/instances_minival.json"
|
||||
VWW_OUTPUT_DIR="new-path-to-visualwakewords-dataset/annotations/"
|
||||
python create_visualwakewords_annotations.py \
|
||||
--train_annotations_file="${MAXITRAIN_ANNOTATIONS_FILE}" \
|
||||
--val_annotations_file="${MINIVAL_ANNOTATIONS_FILE}" \
|
||||
--output_dir="${VWW_OUTPUT_DIR}" \
|
||||
--threshold=0.005 \
|
||||
--foreground_class='person'
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from pycocotools.coco import COCO
|
||||
|
||||
|
||||
def create_visual_wakeword_annotations(annotations_file,
|
||||
visualwakewords_annotations_path,
|
||||
object_area_threshold,
|
||||
foreground_class_name):
|
||||
"""Generate visual wake words annotations file.
|
||||
Loads COCO annotation json files and filters to foreground_class_name/not-foreground_class_name
|
||||
(by default it will be person/not-person) to generate visual wake words annotations file.
|
||||
Each image is assigned a label 1 or 0. The label 1 is assigned if the image contains
at least one foreground_class_name (e.g. person) bounding box whose area is
greater than object_area_threshold (e.g. 5% of the image area).
|
||||
Args:
|
||||
annotations_file: JSON file containing COCO bounding box annotations
|
||||
visualwakewords_annotations_path: output path to annotations file
|
||||
object_area_threshold: threshold on fraction of image area below which
|
||||
small object bounding boxes are filtered
|
||||
foreground_class_name: category from COCO dataset that is filtered by
|
||||
the visual wakewords dataset
|
||||
"""
|
||||
print('Processing {}...'.format(annotations_file))
|
||||
coco = COCO(annotations_file)
|
||||
|
||||
info = {"description": "Visual Wake Words Dataset",
|
||||
"url": "https://arxiv.org/abs/1906.05721",
|
||||
"version": "1.0",
|
||||
"year": 2019,
|
||||
}
|
||||
|
||||
# default object of interest is person
|
||||
foreground_class_id = 1
|
||||
dataset = coco.dataset
|
||||
licenses = dataset['licenses']
|
||||
|
||||
images = dataset['images']
|
||||
# Create category index
|
||||
foreground_category = None
|
||||
background_category = {'supercategory': 'background', 'id': 0, 'name': 'background'}
|
||||
for category in dataset['categories']:
|
||||
if category['name'] == foreground_class_name:
|
||||
foreground_class_id = category['id']
|
||||
foreground_category = category
|
||||
foreground_category['id'] = 1
|
||||
background_category['name'] = "not-{}".format(foreground_category['name'])
|
||||
categories = [background_category, foreground_category]
|
||||
|
||||
if 'annotations' not in dataset:
|
||||
raise KeyError('Need annotations in json file to build the dataset.')
|
||||
new_ann_id = 0
|
||||
annotations = []
|
||||
positive_img_ids = set()
|
||||
foreground_imgs_ids = coco.getImgIds(catIds=foreground_class_id)
|
||||
for img_id in foreground_imgs_ids:
|
||||
img = coco.imgs[img_id]
|
||||
img_area = img['height'] * img['width']
|
||||
for ann_id in coco.getAnnIds(imgIds=img_id, catIds=foreground_class_id):
|
||||
ann = coco.anns[ann_id]
|
||||
if 'area' in ann:
|
||||
normalized_ann_area = ann['area'] / img_area
|
||||
if normalized_ann_area > object_area_threshold:
|
||||
new_ann = {
|
||||
"id": new_ann_id,
|
||||
"image_id": img_id,
|
||||
"category_id": 1,
|
||||
"area": ann["area"],
|
||||
"bbox": ann["bbox"],
|
||||
"iscrowd": ann["iscrowd"],
|
||||
}
|
||||
annotations.append(new_ann)
|
||||
positive_img_ids.add(img_id)
|
||||
new_ann_id += 1
|
||||
print("There are {} images that now have label {}, of the {} images in total.".format(len(positive_img_ids),
|
||||
foreground_class_name,
|
||||
len(coco.imgs)))
|
||||
negative_img_ids = list(set(coco.imgs.keys()) - positive_img_ids)
|
||||
for img_id in negative_img_ids:
|
||||
new_ann = {
|
||||
"id": new_ann_id,
|
||||
"image_id": img_id,
|
||||
"category_id": 0,
|
||||
"area": 0.0,
|
||||
"bbox": [],
|
||||
"iscrowd": 0,
|
||||
}
|
||||
annotations.append(new_ann)
|
||||
new_ann_id += 1
|
||||
|
||||
# Output Visual WakeWords annotations and labels
|
||||
with open(visualwakewords_annotations_path, 'w') as fp:
|
||||
json.dump(
|
||||
{
|
||||
"info": info,
|
||||
"licenses": licenses,
|
||||
'images': images,
|
||||
'annotations': annotations,
|
||||
'categories': categories,
|
||||
}, fp)
|
||||
|
||||
|
||||
def main(args):
|
||||
output_dir = os.path.realpath(os.path.expanduser(args.output_dir))
|
||||
train_annotations_file = os.path.realpath(os.path.expanduser(args.train_annotations_file))
|
||||
val_annotations_file = os.path.realpath(os.path.expanduser(args.val_annotations_file))
|
||||
visualwakewords_annotations_train = os.path.join(
|
||||
output_dir, 'instances_train.json')
|
||||
visualwakewords_annotations_val = os.path.join(
|
||||
output_dir, 'instances_val.json')
|
||||
small_object_area_threshold = args.threshold
|
||||
foreground_class_of_interest = args.foreground_class
|
||||
|
||||
# Create the Visual WakeWords annotations from COCO annotations
|
||||
if not os.path.isdir(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
create_visual_wakeword_annotations(
|
||||
train_annotations_file, visualwakewords_annotations_train,
|
||||
small_object_area_threshold, foreground_class_of_interest)
|
||||
create_visual_wakeword_annotations(
|
||||
val_annotations_file, visualwakewords_annotations_val,
|
||||
small_object_area_threshold, foreground_class_of_interest)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--train_annotations_file', type=str, required=True,
|
||||
help='(COCO) Training annotations JSON file')
|
||||
parser.add_argument('--val_annotations_file', type=str, required=True,
|
||||
help='(COCO) Validation annotations JSON file')
|
||||
parser.add_argument('--output_dir', type=str, default='/tmp/visualwakewords/',
|
||||
help='Output directory where the Visual WakeWords annotations files will be stored')
|
||||
parser.add_argument('--threshold', type=float, default=0.005,
|
||||
help='Threshold of fraction of image area below which small objects are filtered.')
|
||||
parser.add_argument('--foreground_class', type=str, default='person',
|
||||
help='Annotations will have a label indicating if this object is present or absent '
     'in the scene (default is person/not-person).')
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright 2020 Maxim Bonnaerens. All Rights Reserved.
|
||||
#
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# File modified from tensorflow/models/research/slim/datasets/download_mscoco.sh
|
||||
|
||||
# Script to download the COCO dataset. See
|
||||
# http://cocodataset.org/#overview for an overview of the dataset.
|
||||
#
|
||||
# usage:
|
||||
# bash scripts/download_mscoco.sh path-to-COCO-dataset
|
||||
#
|
||||
set -e
|
||||
|
||||
YEAR=${2:-2014}
|
||||
if [ -z "$1" ]; then
|
||||
echo "usage download_mscoco.sh [data dir] (2014|2017)"
|
||||
exit
|
||||
fi
|
||||
|
||||
if [ "$(uname)" == "Darwin" ]; then
|
||||
UNZIP="tar -xf"
|
||||
else
|
||||
UNZIP="unzip -nq"
|
||||
fi
|
||||
|
||||
# Create the output directories.
|
||||
OUTPUT_DIR="${1%/}"
|
||||
mkdir -p "${OUTPUT_DIR}"
|
||||
|
||||
# Helper function to download and unpack a .zip file.
|
||||
function download_and_unzip() {
|
||||
local BASE_URL=${1}
|
||||
local FILENAME=${2}
|
||||
|
||||
if [ ! -f "${FILENAME}" ]; then
|
||||
echo "Downloading ${FILENAME} to $(pwd)"
|
||||
wget -nd -c "${BASE_URL}/${FILENAME}"
|
||||
else
|
||||
echo "Skipping download of ${FILENAME}"
|
||||
fi
|
||||
echo "Unzipping ${FILENAME}"
|
||||
${UNZIP} "${FILENAME}"
|
||||
rm "${FILENAME}"
|
||||
}
|
||||
|
||||
cd "${OUTPUT_DIR}"
|
||||
|
||||
# Download the images.
|
||||
BASE_IMAGE_URL="http://images.cocodataset.org/zips"
|
||||
|
||||
TRAIN_IMAGE_FILE="train${YEAR}.zip"
|
||||
download_and_unzip ${BASE_IMAGE_URL} "${TRAIN_IMAGE_FILE}"
|
||||
TRAIN_IMAGE_DIR="${OUTPUT_DIR}/train${YEAR}"
|
||||
|
||||
VAL_IMAGE_FILE="val${YEAR}.zip"
|
||||
download_and_unzip ${BASE_IMAGE_URL} "${VAL_IMAGE_FILE}"
|
||||
VAL_IMAGE_DIR="${OUTPUT_DIR}/val${YEAR}"
|
||||
|
||||
COMMON_DIR="all$YEAR"
|
||||
mkdir -p "${COMMON_DIR}"
|
||||
for i in ${TRAIN_IMAGE_DIR}/*; do cp --symbolic-link "$i" ${COMMON_DIR}/; done
|
||||
for i in ${VAL_IMAGE_DIR}/*; do cp --symbolic-link "$i" ${COMMON_DIR}/; done
|
||||
|
||||
# Download the annotations.
|
||||
BASE_INSTANCES_URL="http://images.cocodataset.org/annotations"
|
||||
INSTANCES_FILE="annotations_trainval${YEAR}.zip"
|
||||
download_and_unzip ${BASE_INSTANCES_URL} "${INSTANCES_FILE}"
|
(Diff for this file not shown because of its large size.)
|
@ -0,0 +1,249 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision
|
||||
import torchvision.models as models
|
||||
import torchvision.transforms as transforms
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
from importlib import import_module
|
||||
from pyvww.utils import VisualWakeWords
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
best_acc = 0 # best test accuracy
|
||||
start_epoch = 0
|
||||
|
||||
#Arg parser
|
||||
parser = argparse.ArgumentParser(description='PyTorch Visual Wake Words Training')
|
||||
parser.add_argument('--lr', default=0.05, type=float, help='learning rate')
|
||||
parser.add_argument('--epochs', default=900, type=int, help='total epochs')
|
||||
parser.add_argument('--resume', default=None, type=str, help='load from checkpoint')
|
||||
parser.add_argument('--model_arch',
|
||||
default='model_mobilenet_rnnpool', type=str,
|
||||
choices=['model_mobilenet_rnnpool', 'model_mobilenet_2rnnpool'],
|
||||
help='choose architecture among rpool variants')
|
||||
parser.add_argument('--ann', default=None, type=str,
|
||||
help='specify new-path-to-visualwakewords-dataset used in dataset creation step')
|
||||
parser.add_argument('--data', default=None, type=str,
|
||||
help='specify path-to-mscoco-dataset used in dataset creation step')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Data
|
||||
|
||||
class VisualWakeWordsClassification(VisionDataset):
|
||||
"""`Visual Wake Words <https://arxiv.org/abs/1906.05721>`_ Dataset.
|
||||
Args:
|
||||
root (string): Root directory where COCO images are downloaded to.
|
||||
annFile (string): Path to json visual wake words annotation file.
|
||||
transform (callable, optional): A function/transform that takes in an PIL image
|
||||
and returns a transformed version. E.g, ``transforms.ToTensor``
|
||||
target_transform (callable, optional): A function/transform that takes in the
|
||||
target and transforms it.
|
||||
"""
|
||||
def __init__(self, root, annFile, transform=None, target_transform=None, split='val'):
|
||||
# super(VisualWakeWordsClassification, self).__init__(root, annFile, transform, target_transform, split)
|
||||
self.vww = VisualWakeWords(annFile)
|
||||
self.ids = list(sorted(self.vww.imgs.keys()))
|
||||
self.split = split
|
||||
|
||||
self.transform = transform
|
||||
self.target_transform = target_transform
|
||||
self.root = root
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Args:
|
||||
index (int): Index
|
||||
Returns:
|
||||
tuple: Tuple (image, target). target is the index of the target class.
|
||||
"""
|
||||
vww = self.vww
|
||||
img_id = self.ids[index]
|
||||
ann_ids = vww.getAnnIds(imgIds=img_id)
|
||||
target = vww.loadAnns(ann_ids)[0]['category_id']
|
||||
|
||||
path = vww.loadImgs(img_id)[0]['file_name']
|
||||
|
||||
img = Image.open(os.path.join(self.root, path)).convert('RGB')
|
||||
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
return img, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.ids)
|
||||
|
||||
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
|
||||
transform_train = transforms.Compose([
|
||||
# transforms.RandomAffine(10, translate=None, shear=(5,5,5,5), resample=False, fillcolor=0),
|
||||
transforms.RandomResizedCrop(size=(224,224), scale=(0.2,1.0)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
#transforms.RandomAffine(10, translate=None, shear=(5,5,5,5), resample=False, fillcolor=0),
|
||||
# transforms.ColorJitter(brightness=(0.6,1.4), saturation=(0.9,1.1), hue=(-0.1,0.1)),
|
||||
transforms.ToTensor(),
|
||||
normalize
|
||||
])
|
||||
|
||||
transform_test = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
normalize
|
||||
])
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
trainset = VisualWakeWordsClassification(root=os.path.join(args.data,'all2014'),
|
||||
annFile=os.path.join(args.ann, 'annotations/instances_train.json'),
|
||||
transform=transform_train, split='train')
|
||||
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True,
|
||||
num_workers=32)
|
||||
|
||||
testset = VisualWakeWordsClassification(root=os.path.join(args.data,'all2014'),
|
||||
annFile=os.path.join(args.ann, 'annotations/instances_val.json'),
|
||||
transform=transform_test, split='val')
|
||||
|
||||
testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False,
|
||||
num_workers=32)
|
||||
|
||||
|
||||
# Model
|
||||
|
||||
module = import_module(args.model_arch)
|
||||
model = module.mobilenetv2_rnnpool(num_classes=2, width_mult=0.35, last_channel=320)
|
||||
model = model.to(device)
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
|
||||
|
||||
if args.resume:
|
||||
# Load checkpoint.
|
||||
print('==> Resuming from checkpoint..')
|
||||
assert os.path.isdir('./checkpoints/'), 'Error: no checkpoint directory found!'
|
||||
checkpoint = torch.load('./checkpoints/' + args.resume)
model.load_state_dict(checkpoint['model'])  # restore the weights saved by test()
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']
|
||||
|
||||
criterion = nn.CrossEntropyLoss().cuda()
|
||||
|
||||
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=4e-5)
|
||||
|
||||
|
||||
# Training
|
||||
def train(epoch):
|
||||
print('\nEpoch: %d' % epoch)
|
||||
model.train()
|
||||
train_loss = 0
|
||||
correct = 0
|
||||
total = 0
|
||||
train_loader_len = len(trainloader)
|
||||
for batch_idx, (inputs, targets) in enumerate(trainloader):
|
||||
adjust_learning_rate(optimizer, epoch, batch_idx, train_loader_len)
|
||||
|
||||
batch_size = inputs.shape[0]
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
optimizer.zero_grad()
|
||||
outputs = model(inputs)
|
||||
|
||||
loss = criterion(outputs, targets)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
train_loss += loss.item()
|
||||
_, predicted = outputs.max(1)
|
||||
total += targets.size(0)
|
||||
correct += predicted.eq(targets).sum().item()
|
||||
|
||||
print('train_loss: ',train_loss/total, ' acc: ', correct/total)
|
||||
print('->>lr:{:.6f}'.format(optimizer.param_groups[0]['lr']))
|
||||
|
||||
def test(epoch):
|
||||
global best_acc
|
||||
model.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for batch_idx, (inputs, targets) in enumerate(testloader):
|
||||
batch_size = inputs.shape[0]
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
outputs = model(inputs)
|
||||
|
||||
loss = criterion(outputs, targets)
|
||||
|
||||
test_loss += loss.item()
|
||||
_, predicted = outputs.max(1)
|
||||
total += targets.size(0)
|
||||
correct += predicted.eq(targets).sum().item()
|
||||
|
||||
print('test_loss: ',test_loss/total, ' test_acc: ', correct/total)
|
||||
|
||||
# Save checkpoint.
|
||||
print('best acc: ', best_acc)
|
||||
acc = 100.*correct/total
|
||||
if acc > best_acc:
|
||||
print('Saving..')
|
||||
state = {
|
||||
'model': model.state_dict(),
|
||||
'acc': acc,
|
||||
'epoch': epoch,
|
||||
}
|
||||
if not os.path.isdir('./checkpoints/'):
|
||||
os.mkdir('./checkpoints/')
|
||||
torch.save(state, './checkpoints/model_mobilenet_rnnpool.pth')
|
||||
best_acc = acc
|
||||
|
||||
|
||||
from math import cos, pi
|
||||
def adjust_learning_rate(optimizer, epoch, iteration, num_iter):
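# Cosine learning-rate schedule: lr decays from args.lr to 0 over 150 epochs' worth of
# iterations (max_iter); warmup is effectively disabled here since warmup_epoch = 0.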
|
||||
lr = optimizer.param_groups[0]['lr']
|
||||
|
||||
warmup_epoch = 0
|
||||
warmup_iter = warmup_epoch * num_iter
|
||||
current_iter = iteration + epoch * num_iter
|
||||
max_iter = 150 * num_iter
|
||||
|
||||
|
||||
lr = args.lr * (1 + cos(pi * (current_iter - warmup_iter) / (max_iter - warmup_iter))) / 2
|
||||
|
||||
if epoch < warmup_epoch:
|
||||
lr = args.lr * current_iter / warmup_iter
|
||||
|
||||
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
|
||||
for epoch in range(start_epoch, start_epoch+args.epochs):
|
||||
train(epoch)
|
||||
test(epoch)
|
|
@ -0,0 +1,18 @@
# RNNPool quantized sample code

The `rnnpool_quantized.cpp` code takes the activations preceding the RNNPool layer
and produces the output of a quantized RNNPool layer. The input numpy file consists
of all activation patches corresponding to a single image. In `trace_0_input.npy`,
there are 6241 patches of dimensions 8x8 with 4 channels, to which RNNPool is applied.
The output is of size 6241x4x8. This can be compared to the floating-point output stored in
`trace_0_output.npy`.

```shell
g++ -o rnnpool_quantized rnnpool_quantized.cpp

# Usage: ./rnnpool_quantized <#patches> <input.npy> <output.npy>
./rnnpool_quantized 6241 trace_0_input.npy trace_0_output_quantized.npy
```

Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT license.
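For a quick sanity check of the quantized output against the floating-point trace, a minimal sketch (assuming `numpy` is installed and the file names produced by the commands above):

```python
import numpy as np

ref = np.load("trace_0_output.npy")                # floating-point reference
quant = np.load("trace_0_output_quantized.npy")    # produced by ./rnnpool_quantized

# Each file holds one 32-dimensional RNNPool output per patch.
diff = np.abs(ref.reshape(-1, 32) - quant.reshape(-1, 32))
print("max abs error :", diff.max())
print("mean abs error:", diff.mean())
```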
|
|
@ -0,0 +1,59 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
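// Parameters of the two FastGRNN cells used by the quantized RNNPool, stored as int16
// fixed-point values. The trailing comment on each tensor (e.g. //16384) gives its scale
// denominator, i.e. real_value = stored_value / scale.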
|
||||
|
||||
int16_t W1[4][8] = {{7069, 3262, 5513, 4733, -2708, -10109, 5233, 4489}
|
||||
, {-10390, -38, -17036, -404, -1288, -138, 226, -1100}
|
||||
, {1562, -1144, -14616, 4106, -18129, 2064, 831, 2845}
|
||||
, {-1993, -996, -6637, -1105, -1833, 1207, -1910, -1262}
|
||||
}; //16384
|
||||
|
||||
int16_t U1[8][8] = {{15238, -4081, -18973, -1468, 3401, 12650, -911, -1588}
|
||||
, {-1372, -2625, 23200, 5474, 7390, -3379, 5065, 7849}
|
||||
, {-931, -10160, -4142, -3773, 1400, 1952, -7027, -4937}
|
||||
, {-311, 3353, 10395, -410, 2437, 426, 5921, 4664}
|
||||
, {3195, -2369, -20748, -7006, 5303, -2544, -1009, -11564}
|
||||
, {-4775, 5477, -4431, 2161, 829, 18282, 1428, 3197}
|
||||
, {-435, 4946, 11025, 4571, 1986, -2559, -1213, 4943}
|
||||
, {16, 3484, 10337, 5800, 2855, 549, 5397, 561}
|
||||
}; //32768
|
||||
|
||||
int16_t Bg1[1][8] = {{-18778, -9519, 4055, -7310, 8584, -17258, -5281, -7934}
|
||||
}; //16384
|
||||
|
||||
int16_t Bh1[1][8] = {{9658, 19740, -10058, 19114, 17227, 12226, 19080, 15855}
|
||||
}; //32768
|
||||
|
||||
int16_t zeta1 = 32522; //32768
|
||||
|
||||
int16_t nu1 = 235; //32768
|
||||
|
||||
int16_t W2[8][8] = {{-850, 359, -9842, 5701, 7390, -4590, -3959, 2759}
|
||||
, {-1536, -6107, -1978, -5420, -1215, -5065, 77, -4658}
|
||||
, {10036, -340, 745, -3625, 1684, -1927, 2312, 2028}
|
||||
, {-3593, -1295, -997, -1, 1441, 2806, -1718, -3687}
|
||||
, {-287, -221, -1398, 439, -1651, 3409, -19972, -193}
|
||||
, {-6120, -4338, -1679, -9576, 13070, -12784, -56, -5648}
|
||||
, {-5623, -2853, -862, -3739, 2595, -285, -673, -5104}
|
||||
, {-3761, -842, -713, 396, 1405, 3339, -1477, -3670}
|
||||
}; //16384
|
||||
|
||||
int16_t U2[8][8] = {{8755, 2010, -3642, -913, 5998, -2312, -389, -1571}
|
||||
, {-906, 9661, -1875, -328, 4034, -3910, -355, -5117}
|
||||
, {-2433, 1688, 1328, -1493, 4122, 769, -177, 9988}
|
||||
, {-2759, 2240, 1795, 6117, 6542, -6011, 710, 283}
|
||||
, {-3163, 5634, 15468, -1189, 704, -1739, 483, 3409}
|
||||
, {-4224, 5383, -15324, -2616, 19957, 2042, -579, -319}
|
||||
, {181, -1085, 863, 1111, -4614, 4177, 3342, 4059}
|
||||
, {312, 996, -3600, -867, 2397, -1214, -917, 8633}
|
||||
}; //16384
|
||||
|
||||
int16_t Bg2[1][8] = {{-5411, -15415, -13003, -12122, -18931, -17923, -8693, -12151}
|
||||
}; //16384
|
||||
|
||||
int16_t Bh2[1][8] = {{21417, 6457, 6421, 8970, 6601, 836, 3060, 8468}
|
||||
}; //16384
|
||||
|
||||
int16_t zeta2 = 32520; //32768
|
||||
|
||||
int16_t nu2 = 256; //32768
|
||||
|
|
@ -0,0 +1,535 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <cstdlib>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
|
||||
#define MYINT int16_t
|
||||
#define MYITE int16_t
|
||||
|
||||
#include "data.h"
|
||||
|
||||
#define SHIFT
|
||||
|
||||
#ifdef SHIFT
|
||||
#define MYSCL int16_t
|
||||
unordered_map<string, MYSCL> scale = {
|
||||
{"X", 12},
|
||||
|
||||
{"one", 14},
|
||||
|
||||
{"W1",14},
|
||||
{"H1",14},
|
||||
{"U1",15},
|
||||
{"Bg1",14},
|
||||
{"Bh1",15},
|
||||
{"zeta1",15},
|
||||
{"nu1",15},
|
||||
|
||||
{"a1",11},
|
||||
{"b1",13},
|
||||
{"c1",11},
|
||||
{"cBg1",11},
|
||||
{"cBh1",11},
|
||||
{"g1",14},
|
||||
{"h1",14},
|
||||
{"z1",14},
|
||||
{"y1",14},
|
||||
{"w1",14},
|
||||
{"v1",14},
|
||||
{"u1",14},
|
||||
|
||||
{"intermediate", 14},
|
||||
|
||||
{"W2",14},
|
||||
{"H2",14},
|
||||
{"U2",14},
|
||||
{"Bg2",14},
|
||||
{"Bh2",14},
|
||||
{"zeta2",15},
|
||||
{"nu2",15},
|
||||
|
||||
{"a2",14},
|
||||
{"b2",13},
|
||||
{"c2",13},
|
||||
{"cBg2",11},
|
||||
{"cBh2",11},
|
||||
{"g2",14},
|
||||
{"h2",14},
|
||||
{"z2",15},
|
||||
{"y2",14},
|
||||
{"w2",14},
|
||||
{"v2",14},
|
||||
{"u2",14},
|
||||
|
||||
{"Y",14},
|
||||
};
|
||||
#else
|
||||
#define MYSCL int32_t
|
||||
unordered_map<string, MYSCL> scale = {
|
||||
{"X", 4096},
|
||||
|
||||
{"one", 16384},
|
||||
|
||||
{"W1",16384},
|
||||
{"H1",16384},
|
||||
{"U1",32768},
|
||||
{"Bg1",16384},
|
||||
{"Bh1",32768},
|
||||
{"zeta1",32768},
|
||||
{"nu1",32768},
|
||||
|
||||
{"a1",2048},
|
||||
{"b1",8192},
|
||||
{"c1",2048},
|
||||
{"cBg1",2048},
|
||||
{"cBh1",2048},
|
||||
{"g1",16384},
|
||||
{"h1",16384},
|
||||
{"z1",16384},
|
||||
{"y1",16384},
|
||||
{"w1",16384},
|
||||
{"v1",16384},
|
||||
{"u1",16384},
|
||||
|
||||
{"intermediate", 16384},
|
||||
|
||||
{"W2",16384},
|
||||
{"H2",16384},
|
||||
{"U2",16384},
|
||||
{"Bg2",16384},
|
||||
{"Bh2",16384},
|
||||
{"zeta2",32768},
|
||||
{"nu2",32768},
|
||||
|
||||
{"a2",16384},
|
||||
{"b2",8192},
|
||||
{"c2",8192},
|
||||
{"cBg2",2048},
|
||||
{"cBh2",2048},
|
||||
{"g2",16384},
|
||||
{"h2",16384},
|
||||
{"z2",32768},
|
||||
{"y2",16384},
|
||||
{"w2",16384},
|
||||
{"v2",16384},
|
||||
{"u2",16384},
|
||||
|
||||
{"Y",16384},
|
||||
};
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
void MatMul(int16_t* A, int16_t* B, int16_t* C, MYINT I, MYINT J, MYINT K, MYSCL scA, MYSCL scB, MYSCL scC) {
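// Fixed-point matrix multiply C = A x B with per-tensor scales scA, scB, scC.
// addshr provides ceil(log2(J)) bits (or the matching power-of-two divisor) of
// accumulation headroom; shr then rescales each accumulated dot product to scale scC.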
|
||||
|
||||
#ifdef SHIFT
|
||||
MYSCL addshrP = 1, addshr = 0;
|
||||
while (addshrP < J) {
|
||||
addshrP *= 2;
|
||||
addshr += 1;
|
||||
}
|
||||
#else
|
||||
MYSCL addshr = 1;
|
||||
while (addshr < J)
|
||||
addshr *= 2;
|
||||
#endif
|
||||
|
||||
#ifdef SHIFT
|
||||
MYSCL shr = scA + scB - scC - addshr;
|
||||
#else
|
||||
MYSCL shr = (scA * scB) / (scC * addshr);
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < I; i++) {
|
||||
for (int k = 0; k < K; k++) {
|
||||
int32_t s = 0;
|
||||
for (int j = 0; j < J; j++) {
|
||||
#ifdef SHIFT
|
||||
s += ((int32_t)A[i * J + j] * (int32_t)B[j * K + k]) >> addshr;
|
||||
#else
|
||||
s += ((int32_t)A[i * J + j] * (int32_t)B[j * K + k]) / addshr;
|
||||
#endif
|
||||
}
|
||||
#ifdef SHIFT
|
||||
C[i * K + k] = s >> shr;
|
||||
#else
|
||||
C[i * K + k] = s / shr;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline MYINT min(MYINT a, MYINT b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
inline MYINT max(MYINT a, MYINT b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
void MatAdd(int16_t* A, int16_t* B, int16_t* C, MYINT I, MYINT J, MYSCL scA, MYSCL scB, MYSCL scC) {
|
||||
|
||||
MYSCL shrmin = min(scA, scB);
|
||||
#ifdef SHIFT
|
||||
MYSCL shra = scA - shrmin;
|
||||
MYSCL shrb = scB - shrmin;
|
||||
MYSCL shrc = shrmin - scC;
|
||||
#else
|
||||
MYSCL shra = scA / shrmin;
|
||||
MYSCL shrb = scB / shrmin;
|
||||
MYSCL shrc = shrmin / scC;
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < I; i++) {
|
||||
for (int j = 0; j < J; j++) {
|
||||
#ifdef SHIFT
|
||||
C[i * J + j] = ((A[i * J + j] >> (shra + shrc)) + (B[i * J + j] >> (shrb + shrc)));
|
||||
#else
|
||||
C[i * J + j] = ((A[i * J + j] / (shra * shrc)) + (B[i * J + j] / (shrb * shrc)));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ScalarMatSub(int16_t A, int16_t* B, int16_t* C, MYINT I, MYINT J, MYSCL scA, MYSCL scB, MYSCL scC) {
|
||||
|
||||
MYSCL shrmin = min(scA, scB);
|
||||
#ifdef SHIFT
|
||||
MYSCL shra = scA - shrmin;
|
||||
MYSCL shrb = scB - shrmin;
|
||||
MYSCL shrc = shrmin - scC;
|
||||
#else
|
||||
MYSCL shra = scA / shrmin;
|
||||
MYSCL shrb = scB / shrmin;
|
||||
MYSCL shrc = shrmin / scC;
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < I; i++) {
|
||||
for (int j = 0; j < J; j++) {
|
||||
#ifdef SHIFT
|
||||
C[i * J + j] = ((A >> (shra + shrc)) - (B[i * J + j] >> (shrb + shrc)));
|
||||
#else
|
||||
C[i * J + j] = ((A / (shra * shrc)) - (B[i * J + j] / (shrb * shrc)));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ScalarMatAdd(int16_t A, int16_t* B, int16_t* C, MYINT I, MYINT J, MYSCL scA, MYSCL scB, MYSCL scC) {
|
||||
|
||||
MYSCL shrmin = min(scA, scB);
|
||||
#ifdef SHIFT
|
||||
MYSCL shra = scA - shrmin;
|
||||
MYSCL shrb = scB - shrmin;
|
||||
MYSCL shrc = shrmin - scC;
|
||||
#else
|
||||
MYSCL shra = scA / shrmin;
|
||||
MYSCL shrb = scB / shrmin;
|
||||
MYSCL shrc = shrmin / scC;
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < I; i++) {
|
||||
for (int j = 0; j < J; j++) {
|
||||
#ifdef SHIFT
|
||||
C[i * J + j] = ((A >> (shra + shrc)) + (B[i * J + j] >> (shrb + shrc)));
|
||||
#else
|
||||
C[i * J + j] = ((A / (shra * shrc)) + (B[i * J + j] / (shrb * shrc)));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void HadMul(int16_t* A, int16_t* B, int16_t* C, MYINT I, MYINT J, MYSCL scA, MYSCL scB, MYSCL scC) {
|
||||
|
||||
#ifdef SHIFT
|
||||
MYSCL shr = (scA + scB) - scC;
|
||||
#else
|
||||
MYSCL shr = (scA * scB) / scC;
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < I; i++) {
|
||||
for (int j = 0; j < J; j++) {
|
||||
#ifdef SHIFT
|
||||
C[i * J + j] = (((int32_t)A[i * J + j]) * ((int32_t)B[i * J + j])) >> shr;
|
||||
#else
|
||||
C[i * J + j] = (((int32_t)A[i * J + j]) * ((int32_t)B[i * J + j])) / shr;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ScalarMul(int16_t A, int16_t* B, int16_t* C, MYINT I, MYINT J, MYSCL scA, MYSCL scB, MYSCL scC) {
|
||||
|
||||
#ifdef SHIFT
|
||||
MYSCL shr = (scA + scB) - scC;
|
||||
#else
|
||||
MYSCL shr = (scA * scB) / scC;
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < I; i++) {
|
||||
for (int j = 0; j < J; j++) {
|
||||
#ifdef SHIFT
|
||||
C[i * J + j] = ((int32_t)(A) * (int32_t)(B[i * J + j])) >> shr;
|
||||
#else
|
||||
C[i * J + j] = ((int32_t)(A) * (int32_t)(B[i * J + j])) / shr;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SigmoidNew16(int16_t* A, MYINT I, MYINT J, int16_t* B) {
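// Piecewise-linear sigmoid approximation in fixed point: with the input at scale 2^11
// (2048 == 1.0), computes clamp((x + 1) / 2, 0, 1) and rescales the result to 2^14.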
|
||||
for (MYITE i = 0; i < I; i++) {
|
||||
for (MYITE j = 0; j < J; j++) {
|
||||
int16_t a = A[i * J + j];
|
||||
B[i * J + j] = 8 * max(min((a + 2048) / 2, 2048), 0);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void TanHNew16(int16_t* A, MYINT I, MYINT J, int16_t* B) {
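// Piecewise-linear tanh approximation: clamp(x, -1, 1) at input scale 2^11, rescaled to 2^14.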
|
||||
for (MYITE i = 0; i < I; i++) {
|
||||
for (MYITE j = 0; j < J; j++) {
|
||||
int16_t a = A[i * J + j];
|
||||
B[i * J + j] = 8 * max(min(a, 2048), -2048);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void reverse(int16_t* A, int16_t* B, int I, int J) {
|
||||
for (int i = 0; i < I; i++) {
|
||||
for (int j = 0; j < J; j++) {
|
||||
B[i * J + j] = A[(I - i - 1) * J + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void print(int16_t* var, int I, int J, MYSCL scale) {
|
||||
for (int i = 0; i < I; i++) {
|
||||
for (int j = 0; j < J; j++) {
|
||||
cout << ((float)var[i * J + j]) / scale << " ";
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
//exit(1);
|
||||
}
|
||||
|
||||
void FastGRNN1(int16_t X[8][4], int16_t* H, int timestep) {
|
||||
memset(&H[0], 0, 8 * 2);
|
||||
|
||||
for (int i = 0; i < timestep; i++) {
|
||||
int16_t a[1][8];
|
||||
MatMul(&X[i][0], &W1[0][0], &a[0][0], 1, 4, 8, scale["X"], scale["W1"], scale["a1"]);
|
||||
int16_t b[1][8];
|
||||
MatMul(&H[0], &U1[0][0], &b[0][0], 1, 8, 8, scale["H1"], scale["U1"], scale["b1"]);
|
||||
int16_t c[1][8];
|
||||
MatAdd(&a[0][0], &b[0][0], &c[0][0], 1, 8, scale["a1"], scale["b1"], scale["c1"]);
|
||||
int16_t cBg[1][8];
|
||||
MatAdd(&c[0][0], &Bg1[0][0], &cBg[0][0], 1, 8, scale["c1"], scale["Bg1"], scale["cBg1"]);
|
||||
int16_t g[1][8];
|
||||
SigmoidNew16(&cBg[0][0], 1, 8, &g[0][0]);
|
||||
int16_t cBh[1][8];
|
||||
MatAdd(&c[0][0], &Bh1[0][0], &cBh[0][0], 1, 8, scale["c1"], scale["Bh1"], scale["cBh1"]);
|
||||
int16_t h[1][8];
|
||||
TanHNew16(&cBh[0][0], 1, 8, &h[0][0]);
|
||||
int16_t z[1][8];
|
||||
HadMul(&g[0][0], &H[0], &z[0][0], 1, 8, scale["g1"], scale["H1"], scale["z1"]);
|
||||
int16_t y[1][8];
|
||||
ScalarMatSub(16384, &g[0][0], &y[0][0], 1, 8, scale["one"], scale["g1"], scale["y1"]);
|
||||
int16_t w[1][8];
|
||||
ScalarMul(zeta1, &y[0][0], &w[0][0], 1, 8, scale["zeta1"], scale["y1"], scale["w1"]);
|
||||
int16_t v[1][8];
|
||||
ScalarMatAdd(nu1, &w[0][0], &v[0][0], 1, 8, scale["nu1"], scale["w1"], scale["v1"]);
|
||||
int16_t u[1][8];
|
||||
HadMul(&w[0][0], &h[0][0], &u[0][0], 1, 8, scale["w1"], scale["h1"], scale["u1"]);
|
||||
|
||||
MatAdd(&z[0][0], &u[0][0], &H[0], 1, 8, scale["z1"], scale["u1"], scale["H1"]);
|
||||
}
|
||||
}
|
||||
|
||||
void FastGRNN2(int16_t X[8][8], int16_t* H, int timestep) {
|
||||
memset(&H[0], 0, 8 * 2);
|
||||
|
||||
for (int i = 0; i < timestep; i++) {
|
||||
int16_t a[1][8];
|
||||
MatMul(&X[i][0], &W2[0][0], &a[0][0], 1, 8, 8, scale["intermediate"], scale["W2"], scale["a2"]);
|
||||
|
||||
int16_t b[1][8];
|
||||
MatMul(&H[0], &U2[0][0], &b[0][0], 1, 8, 8, scale["H2"], scale["U2"], scale["b2"]);
|
||||
int16_t c[1][8];
|
||||
MatAdd(&a[0][0], &b[0][0], &c[0][0], 1, 8, scale["a2"], scale["b2"], scale["c2"]);
|
||||
int16_t cBg[1][8];
|
||||
MatAdd(&c[0][0], &Bg2[0][0], &cBg[0][0], 1, 8, scale["c2"], scale["Bg2"], scale["cBg2"]);
|
||||
int16_t g[1][8];
|
||||
SigmoidNew16(&cBg[0][0], 1, 8, &g[0][0]);
|
||||
int16_t cBh[1][8];
|
||||
MatAdd(&c[0][0], &Bh2[0][0], &cBh[0][0], 1, 8, scale["c2"], scale["Bh2"], scale["cBh2"]);
|
||||
int16_t h[1][8];
|
||||
TanHNew16(&cBh[0][0], 1, 8, &h[0][0]);
|
||||
int16_t z[1][8];
|
||||
HadMul(&g[0][0], &H[0], &z[0][0], 1, 8, scale["g2"], scale["H2"], scale["z2"]);
|
||||
int16_t y[1][8];
|
||||
ScalarMatSub(16384, &g[0][0], &y[0][0], 1, 8, scale["one"], scale["g2"], scale["y2"]);
|
||||
int16_t w[1][8];
|
||||
ScalarMul(zeta2, &y[0][0], &w[0][0], 1, 8, scale["zeta2"], scale["y2"], scale["w2"]);
|
||||
int16_t v[1][8];
|
||||
ScalarMatAdd(nu2, &w[0][0], &v[0][0], 1, 8, scale["nu2"], scale["w2"], scale["v2"]);
|
||||
int16_t u[1][8];
|
||||
HadMul(&w[0][0], &h[0][0], &u[0][0], 1, 8, scale["w2"], scale["h2"], scale["u2"]);
|
||||
|
||||
MatAdd(&z[0][0], &u[0][0], &H[0], 1, 8, scale["z2"], scale["u2"], scale["H2"]);
|
||||
}
|
||||
}
|
||||
|
||||
void RNNPool(int16_t X[8][8][4], int16_t pred[1][32]) {
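// Quantized RNNPool over one 8x8 patch with 4 input channels: FastGRNN1 summarizes each
// row and, in a second pass, each column; FastGRNN2 is run forward and reversed over the
// row and column summaries; the four resulting 8-dim hidden states form the 32-dim output.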
|
||||
|
||||
int16_t biinput1[8][8], biinput1r[8][8];
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int16_t subX[8][4];
|
||||
for (int j = 0; j < 8; j++) {
|
||||
for (int k = 0; k < 4; k++) {
|
||||
subX[j][k] = X[i][j][k];
|
||||
}
|
||||
}
|
||||
int16_t H[1][8];
|
||||
FastGRNN1(subX, &H[0][0], 8);
|
||||
|
||||
for (int j = 0; j < 8; j++) {
|
||||
biinput1[i][j] = H[0][j];
|
||||
}
|
||||
}
|
||||
|
||||
int16_t res1[1][8], res2[1][8];
|
||||
FastGRNN2(biinput1, &res1[0][0], 8);
|
||||
reverse(&biinput1[0][0], &biinput1r[0][0], 8, 8);
|
||||
FastGRNN2(biinput1r, &res2[0][0], 8);
|
||||
|
||||
int16_t biinput2[8][8], biinput2r[8][8];
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int16_t subX[8][4];
|
||||
for (int j = 0; j < 8; j++) {
|
||||
for (int k = 0; k < 4; k++) {
|
||||
subX[j][k] = X[j][i][k];
|
||||
}
|
||||
}
|
||||
int16_t H[1][8];
|
||||
FastGRNN1(subX, &H[0][0], 8);
|
||||
|
||||
for (int j = 0; j < 8; j++) {
|
||||
biinput2[i][j] = H[0][j];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int16_t res3[1][8], res4[1][8];
|
||||
FastGRNN2(biinput2, &res3[0][0], 8);
|
||||
reverse(&biinput2[0][0], &biinput2r[0][0], 8, 8);
|
||||
FastGRNN2(biinput2r, &res4[0][0], 8);
|
||||
|
||||
for (int i = 0; i < 8; i++)
|
||||
pred[0][i] = res1[0][i];
|
||||
for (int i = 0; i < 8; i++)
|
||||
pred[0][i + 8] = res2[0][i];
|
||||
for (int i = 0; i < 8; i++)
|
||||
pred[0][i + 16] = res3[0][i];
|
||||
for (int i = 0; i < 8; i++)
|
||||
pred[0][i + 24] = res4[0][i];
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
string inputfile, outputfile;
|
||||
int patches;
|
||||
if (argc != 4) {
|
||||
cerr << "Improper number of arguments" << endl;
|
||||
return -1;
|
||||
}
|
||||
else {
|
||||
patches = atoi(argv[1]);
|
||||
inputfile = string(argv[2]);
|
||||
outputfile = string(argv[3]);
|
||||
}
|
||||
|
||||
fstream Xfile, Yfile;
|
||||
|
||||
Xfile.open(inputfile, ios::in | ios::binary);
|
||||
Yfile.open(outputfile, ios::out | ios::binary);
|
||||
|
||||
|
||||
char line[8];
|
||||
Xfile.read(line, 8);
|
||||
int headerSize = 0;                       // zero-initialize: only the low 2 bytes are read below
Xfile.read((char*)&headerSize, 1 * 2);    // little-endian uint16 header length (numpy format v1.0)
|
||||
|
||||
char* headerLine = new char[headerSize]; //Ignored
|
||||
Xfile.read(headerLine, headerSize);
|
||||
delete[] headerLine;
|
||||
|
||||
char numpyMagix = 147;
|
||||
char numpyVersionMajor = 1, numpyVersionMinor = 0;
|
||||
string numpyMetaHeader = "";
|
||||
numpyMetaHeader += numpyMagix;
|
||||
numpyMetaHeader += "NUMPY";
|
||||
numpyMetaHeader += numpyVersionMajor;
|
||||
numpyMetaHeader += numpyVersionMinor;
|
||||
|
||||
string numpyHeader = "{'descr': '<f4', 'fortran_order': False, 'shape': (" + to_string(patches) + ", 1, 32), }";
|
||||
|
||||
for (int i = numpyHeader.size() + numpyMetaHeader.size() + 2; i % 64 != 64 - 1; i++) {
|
||||
numpyHeader += ' ';
|
||||
}
|
||||
numpyHeader += (char)(10);
|
||||
|
||||
char a = numpyHeader.size() / 256, b = numpyHeader.size() % 256;
|
||||
Yfile << numpyMetaHeader;
|
||||
Yfile << b << a;
|
||||
Yfile << numpyHeader;
|
||||
|
||||
int total = 0;
|
||||
int correct = 0;
|
||||
|
||||
for (int i = 0; i < patches; i++) {  // number of patches is passed on the command line
|
||||
|
||||
float Xline[256];
|
||||
Xfile.read((char*)&Xline[0], 256 * 4);
|
||||
|
||||
|
||||
int16_t y;
|
||||
int16_t reshapedX[8][8][4];
|
||||
|
||||
for (int a = 0; a < 4; a++) {
|
||||
for (int b = 0; b < 8; b++) {
|
||||
for (int c = 0; c < 8; c++) {
|
||||
#ifdef SHIFT
|
||||
reshapedX[b][c][a] = (int16_t)((Xline[a * 64 + b * 8 + c * 1]) * pow(2, scale["X"]));
|
||||
#else
|
||||
reshapedX[b][c][a] = (int16_t)((Xline[a * 64 + b * 8 + c * 1]) * scale["X"]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int16_t pred[1][32];
|
||||
RNNPool(reshapedX, pred);
|
||||
|
||||
for (int j = 0; j < 32; j++) {
|
||||
float val = ((float)pred[0][j]) / pow(2, scale["Y"]);
|
||||
Yfile.write((char*)&val, sizeof(float));
|
||||
}
|
||||
}
|
||||
Xfile.close();
|
||||
Yfile.close();
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -144,8 +144,8 @@ class RNNCell(nn.Module):
|
|||
|
||||
def get_model_size(self):
|
||||
'''
|
||||
Function to get aimed model size
|
||||
'''
|
||||
Function to get aimed model size
|
||||
'''
|
||||
mats = self.getVars()
|
||||
endW = self._num_W_matrices
|
||||
endU = endW + self._num_U_matrices
|
||||
|
@ -261,7 +261,7 @@ class FastGRNNCell(RNNCell):
|
|||
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
|
||||
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
|
||||
|
||||
self.copy_previous_UW()
|
||||
# self.copy_previous_UW()
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
|
@ -330,7 +330,7 @@ class FastGRNNCUDACell(RNNCell):
|
|||
'''
|
||||
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, wSparsity=1.0, uSparsity=1.0, name="FastGRNNCUDACell"):
|
||||
super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_non_linearity, update_nonlinearity,
|
||||
super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_nonlinearity, update_nonlinearity,
|
||||
1, 1, 2, wRank, uRank, wSparsity, uSparsity)
|
||||
if utils.findCUDA() is None:
|
||||
raise Exception('FastGRNNCUDA is supported only on GPU devices.')
|
||||
|
@ -967,63 +967,115 @@ class BaseRNN(nn.Module):
|
|||
[batchSize, timeSteps, inputDims]
|
||||
'''
|
||||
|
||||
def __init__(self, cell: RNNCell, batch_first=False):
|
||||
def __init__(self, cell: RNNCell, batch_first=False, cell_reverse: RNNCell=None, bidirectional=False):
|
||||
super(BaseRNN, self).__init__()
|
||||
self._RNNCell = cell
|
||||
self.RNNCell = cell
|
||||
self._batch_first = batch_first
|
||||
self._bidirectional = bidirectional
|
||||
if cell_reverse is not None:
|
||||
self.RNNCell_reverse = cell_reverse
|
||||
elif self._bidirectional:
|
||||
self.RNNCell_reverse = cell
|
||||
|
||||
def getVars(self):
|
||||
return self._RNNCell.getVars()
|
||||
return self.RNNCell.getVars()
|
||||
|
||||
def forward(self, input, hiddenState=None,
|
||||
cellState=None):
|
||||
self.device = input.device
|
||||
self.num_directions = 2 if self._bidirectional else 1
|
||||
# hidden
|
||||
# for i in range(num_directions):
|
||||
hiddenStates = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
self.RNNCell.output_size]).to(self.device)
|
||||
|
||||
if self._bidirectional:
|
||||
hiddenStates_reverse = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self.RNNCell_reverse.output_size]).to(self.device)
|
||||
|
||||
if hiddenState is None:
|
||||
hiddenState = torch.zeros(
|
||||
[input.shape[0] if self._batch_first else input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
[self.num_directions, input.shape[0] if self._batch_first else input.shape[1],
|
||||
self.RNNCell.output_size]).to(self.device)
|
||||
|
||||
if self._batch_first is True:
|
||||
if self._RNNCell.cellType == "LSTMLR":
|
||||
if self.RNNCell.cellType == "LSTMLR":
|
||||
cellStates = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
self.RNNCell.output_size]).to(self.device)
|
||||
if self._bidirectional:
|
||||
cellStates_reverse = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self.RNNCell_reverse.output_size]).to(self.device)
|
||||
if cellState is None:
|
||||
cellState = torch.zeros(
|
||||
[input.shape[0], self._RNNCell.output_size]).to(self.device)
|
||||
[self.num_directions, input.shape[0], self.RNNCell.output_size]).to(self.device)
|
||||
for i in range(0, input.shape[1]):
|
||||
hiddenState, cellState = self._RNNCell(
|
||||
input[:, i, :], (hiddenState, cellState))
|
||||
hiddenStates[:, i, :] = hiddenState
|
||||
cellStates[:, i, :] = cellState
|
||||
return hiddenStates, cellStates
|
||||
hiddenState[0], cellState[0] = self.RNNCell(
|
||||
input[:, i, :], (hiddenState[0].clone(), cellState[0].clone()))
|
||||
hiddenStates[:, i, :] = hiddenState[0]
|
||||
cellStates[:, i, :] = cellState[0]
|
||||
if self._bidirectional:
|
||||
hiddenState[1], cellState[1] = self.RNNCell_reverse(
|
||||
input[:, input.shape[1]-i-1, :], (hiddenState[1].clone(), cellState[1].clone()))
|
||||
hiddenStates_reverse[:, i, :] = hiddenState[1]
|
||||
cellStates_reverse[:, i, :] = cellState[1]
|
||||
if not self._bidirectional:
|
||||
return hiddenStates, cellStates
|
||||
else:
|
||||
return torch.cat([hiddenStates,hiddenStates_reverse],-1), torch.cat([cellStates,cellStates_reverse],-1)
|
||||
else:
|
||||
for i in range(0, input.shape[1]):
|
||||
hiddenState = self._RNNCell(input[:, i, :], hiddenState)
|
||||
hiddenStates[:, i, :] = hiddenState
|
||||
return hiddenStates
|
||||
hiddenState[0] = self.RNNCell(input[:, i, :], hiddenState[0].clone())
|
||||
hiddenStates[:, i, :] = hiddenState[0]
|
||||
if self._bidirectional:
|
||||
hiddenState[1] = self.RNNCell_reverse(
|
||||
input[:, input.shape[1]-i-1, :], hiddenState[1].clone())
|
||||
hiddenStates_reverse[:, i, :] = hiddenState[1]
|
||||
if not self._bidirectional:
|
||||
return hiddenStates
|
||||
else:
|
||||
return torch.cat([hiddenStates,hiddenStates_reverse],-1)
|
||||
else:
|
||||
if self._RNNCell.cellType == "LSTMLR":
|
||||
if self.RNNCell.cellType == "LSTMLR":
|
||||
cellStates = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
self.RNNCell.output_size]).to(self.device)
|
||||
if self._bidirectional:
|
||||
cellStates_reverse = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self.RNNCell_reverse.output_size]).to(self.device)
|
||||
if cellState is None:
|
||||
cellState = torch.zeros(
|
||||
[input.shape[1], self._RNNCell.output_size]).to(self.device)
|
||||
[self.num_directions, input.shape[1], self.RNNCell.output_size]).to(self.device)
|
||||
for i in range(0, input.shape[0]):
|
||||
hiddenState, cellState = self._RNNCell(
|
||||
input[i, :, :], (hiddenState, cellState))
|
||||
hiddenStates[i, :, :] = hiddenState
|
||||
cellStates[i, :, :] = cellState
|
||||
return hiddenStates, cellStates
|
||||
hiddenState[0], cellState[0] = self.RNNCell(
|
||||
input[i, :, :], (hiddenState[0].clone(), cellState[0].clone()))
|
||||
hiddenStates[i, :, :] = hiddenState[0]
|
||||
cellStates[i, :, :] = cellState[0]
|
||||
if self._bidirectional:
|
||||
hiddenState[1], cellState[1] = self.RNNCell_reverse(
|
||||
input[input.shape[0]-i-1, :, :], (hiddenState[1].clone(), cellState[1].clone()))
|
||||
hiddenStates_reverse[i, :, :] = hiddenState[1]
|
||||
cellStates_reverse[i, :, :] = cellState[1]
|
||||
if not self._bidirectional:
|
||||
return hiddenStates, cellStates
|
||||
else:
|
||||
return torch.cat([hiddenStates,hiddenStates_reverse],-1), torch.cat([cellStates,cellStates_reverse],-1)
|
||||
else:
|
||||
for i in range(0, input.shape[0]):
|
||||
hiddenState = self._RNNCell(input[i, :, :], hiddenState)
|
||||
hiddenStates[i, :, :] = hiddenState
|
||||
return hiddenStates
|
||||
hiddenState[0] = self.RNNCell(input[i, :, :], hiddenState[0].clone())
|
||||
hiddenStates[i, :, :] = hiddenState[0]
|
||||
if self._bidirectional:
|
||||
hiddenState[1] = self.RNNCell_reverse(
|
||||
input[input.shape[0]-i-1, :, :], hiddenState[1].clone())
|
||||
hiddenStates_reverse[i, :, :] = hiddenState[1]
|
||||
if not self._bidirectional:
|
||||
return hiddenStates
|
||||
else:
|
||||
return torch.cat([hiddenStates,hiddenStates_reverse],-1)
|
||||
|
||||
|
||||
class LSTM(nn.Module):
|
||||
|
@ -1031,14 +1083,26 @@ class LSTM(nn.Module):
|
|||
|
||||
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
wSparsity=1.0, uSparsity=1.0, batch_first=False):
|
||||
wSparsity=1.0, uSparsity=1.0, batch_first=False,
|
||||
bidirectional=False, is_shared_bidirectional=True):
|
||||
super(LSTM, self).__init__()
|
||||
self._bidirectional = bidirectional
|
||||
self._batch_first = batch_first
|
||||
self._is_shared_bidirectional = is_shared_bidirectional
|
||||
self.cell = LSTMLRCell(input_size, hidden_size,
|
||||
gate_nonlinearity=gate_nonlinearity,
|
||||
update_nonlinearity=update_nonlinearity,
|
||||
wRank=wRank, uRank=uRank,
|
||||
wSparsity=wSparsity, uSparsity=uSparsity)
|
||||
self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
|
||||
self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
|
||||
|
||||
if self._bidirectional is True and self._is_shared_bidirectional is False:
|
||||
self.cell_reverse = LSTMLRCell(input_size, hidden_size,
|
||||
gate_nonlinearity=gate_nonlinearity,
|
||||
update_nonlinearity=update_nonlinearity,
|
||||
wRank=wRank, uRank=uRank,
|
||||
wSparsity=wSparsity, uSparsity=uSparsity)
|
||||
self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)
|
||||
|
||||
def forward(self, input, hiddenState=None, cellState=None):
|
||||
return self.unrollRNN(input, hiddenState, cellState)
|
||||
|
@ -1049,14 +1113,26 @@ class GRU(nn.Module):
|
|||
|
||||
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
wSparsity=1.0, uSparsity=1.0, batch_first=False):
|
||||
wSparsity=1.0, uSparsity=1.0, batch_first=False,
|
||||
bidirectional=False, is_shared_bidirectional=True):
|
||||
super(GRU, self).__init__()
|
||||
self._bidirectional = bidirectional
|
||||
self._batch_first = batch_first
|
||||
self._is_shared_bidirectional = is_shared_bidirectional
|
||||
self.cell = GRULRCell(input_size, hidden_size,
|
||||
gate_nonlinearity=gate_nonlinearity,
|
||||
update_nonlinearity=update_nonlinearity,
|
||||
wRank=wRank, uRank=uRank,
|
||||
wSparsity=wSparsity, uSparsity=uSparsity)
|
||||
self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
|
||||
self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
|
||||
|
||||
if self._bidirectional is True and self._is_shared_bidirectional is False:
|
||||
self.cell_reverse = GRULRCell(input_size, hidden_size,
|
||||
gate_nonlinearity=gate_nonlinearity,
|
||||
update_nonlinearity=update_nonlinearity,
|
||||
wRank=wRank, uRank=uRank,
|
||||
wSparsity=wSparsity, uSparsity=uSparsity)
|
||||
self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)
|
||||
|
||||
def forward(self, input, hiddenState=None, cellState=None):
|
||||
return self.unrollRNN(input, hiddenState, cellState)
|
||||
|
@ -1067,14 +1143,26 @@ class UGRNN(nn.Module):
|
|||
|
||||
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
wSparsity=1.0, uSparsity=1.0, batch_first=False):
|
||||
wSparsity=1.0, uSparsity=1.0, batch_first=False,
|
||||
bidirectional=False, is_shared_bidirectional=True):
|
||||
super(UGRNN, self).__init__()
|
||||
self._bidirectional = bidirectional
|
||||
self._batch_first = batch_first
|
||||
self._is_shared_bidirectional = is_shared_bidirectional
|
||||
self.cell = UGRNNLRCell(input_size, hidden_size,
|
||||
gate_nonlinearity=gate_nonlinearity,
|
||||
update_nonlinearity=update_nonlinearity,
|
||||
wRank=wRank, uRank=uRank,
|
||||
wSparsity=wSparsity, uSparsity=uSparsity)
|
||||
self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
|
||||
self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
|
||||
|
||||
if self._bidirectional is True and self._is_shared_bidirectional is False:
|
||||
self.cell_reverse = UGRNNLRCell(input_size, hidden_size,
|
||||
gate_nonlinearity=gate_nonlinearity,
|
||||
update_nonlinearity=update_nonlinearity,
|
||||
wRank=wRank, uRank=uRank,
|
||||
wSparsity=wSparsity, uSparsity=uSparsity)
|
||||
self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)
|
||||
|
||||
def forward(self, input, hiddenState=None, cellState=None):
|
||||
return self.unrollRNN(input, hiddenState, cellState)
|
||||
|
@ -1085,15 +1173,28 @@ class FastRNN(nn.Module):
|
|||
|
||||
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
wSparsity=1.0, uSparsity=1.0, alphaInit=-3.0, betaInit=3.0, batch_first=False):
|
||||
wSparsity=1.0, uSparsity=1.0, alphaInit=-3.0, betaInit=3.0,
|
||||
batch_first=False, bidirectional=False, is_shared_bidirectional=True):
|
||||
super(FastRNN, self).__init__()
|
||||
self._bidirectional = bidirectional
|
||||
self._batch_first = batch_first
|
||||
self._is_shared_bidirectional = is_shared_bidirectional
|
||||
self.cell = FastRNNCell(input_size, hidden_size,
|
||||
gate_nonlinearity=gate_nonlinearity,
|
||||
update_nonlinearity=update_nonlinearity,
|
||||
wRank=wRank, uRank=uRank,
|
||||
wSparsity=wSparsity, uSparsity=uSparsity,
|
||||
alphaInit=alphaInit, betaInit=betaInit)
|
||||
self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
|
||||
self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
|
||||
|
||||
if self._bidirectional is True and self._is_shared_bidirectional is False:
|
||||
self.cell_reverse = FastRNNCell(input_size, hidden_size,
|
||||
gate_nonlinearity=gate_nonlinearity,
|
||||
update_nonlinearity=update_nonlinearity,
|
||||
wRank=wRank, uRank=uRank,
|
||||
wSparsity=wSparsity, uSparsity=uSparsity,
|
||||
alphaInit=alphaInit, betaInit=betaInit)
|
||||
self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)
|
||||
|
||||
def forward(self, input, hiddenState=None, cellState=None):
|
||||
return self.unrollRNN(input, hiddenState, cellState)
|
||||
|
@ -1105,15 +1206,27 @@ class FastGRNN(nn.Module):
|
|||
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
|
||||
batch_first=False):
|
||||
batch_first=False, bidirectional=False, is_shared_bidirectional=True):
|
||||
super(FastGRNN, self).__init__()
|
||||
self._bidirectional = bidirectional
|
||||
self._batch_first = batch_first
|
||||
self._is_shared_bidirectional = is_shared_bidirectional
|
||||
self.cell = FastGRNNCell(input_size, hidden_size,
|
||||
gate_nonlinearity=gate_nonlinearity,
|
||||
update_nonlinearity=update_nonlinearity,
|
||||
wRank=wRank, uRank=uRank,
|
||||
wSparsity=wSparsity, uSparsity=uSparsity,
|
||||
zetaInit=zetaInit, nuInit=nuInit)
|
||||
self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
|
||||
self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
|
||||
|
||||
if self._bidirectional is True and self._is_shared_bidirectional is False:
|
||||
self.cell_reverse = FastGRNNCell(input_size, hidden_size,
|
||||
gate_nonlinearity=gate_nonlinearity,
|
||||
update_nonlinearity=update_nonlinearity,
|
||||
wRank=wRank, uRank=uRank,
|
||||
wSparsity=wSparsity, uSparsity=uSparsity,
|
||||
zetaInit=zetaInit, nuInit=nuInit)
|
||||
self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)
|
||||
|
||||
def getVars(self):
|
||||
return self.unrollRNN.getVars()
|
||||
|
@ -1222,8 +1335,8 @@ class FastGRNNCUDA(nn.Module):
|
|||
|
||||
def get_model_size(self):
|
||||
'''
|
||||
Function to get aimed model size
|
||||
'''
|
||||
Function to get aimed model size
|
||||
'''
|
||||
mats = self.getVars()
|
||||
endW = self._num_W_matrices
|
||||
endU = endW + self._num_U_matrices
|
||||
|
|
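The hunks above add (optionally shared-cell) bidirectional unrolling to the low-rank RNN wrappers. A minimal usage sketch, assuming `edgeml_pytorch` is installed as described in the pytorch README; shapes follow the `[batchSize, timeSteps, inputDims]` convention noted in the BaseRNN docstring:

```python
import torch
from edgeml_pytorch.graph.rnn import FastGRNN

# Shared-cell bidirectional FastGRNN: the forward and reverse passes reuse one cell and
# their hidden states are concatenated along the last dimension.
rnn = FastGRNN(input_size=4, hidden_size=8, batch_first=True,
               bidirectional=True, is_shared_bidirectional=True)

x = torch.randn(2, 6, 4)        # (batchSize, timeSteps, inputDims)
states = rnn(x)                 # shape (2, 6, 16), i.e. 2 * hidden_size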
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from edgeml_pytorch.graph.rnn import *
|
||||
|
||||
class RNNPool(nn.Module):
|
||||
def __init__(self, nRows, nCols, nHiddenDims,
|
||||
nHiddenDimsBiDir, inputDims):
|
||||
super(RNNPool, self).__init__()
|
||||
self.nRows = nRows
|
||||
self.nCols = nCols
|
||||
self.inputDims = inputDims
|
||||
self.nHiddenDims = nHiddenDims
|
||||
self.nHiddenDimsBiDir = nHiddenDimsBiDir
|
||||
|
||||
self._build()
|
||||
|
||||
def _build(self):
|
||||
|
||||
self.cell_rnn = FastGRNN(self.inputDims, self.nHiddenDims, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", zetaInit=100.0, nuInit=-100.0,
|
||||
batch_first=False, bidirectional=False)
|
||||
|
||||
self.cell_bidirrnn = FastGRNN(self.nHiddenDims, self.nHiddenDimsBiDir, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", zetaInit=100.0, nuInit=-100.0,
|
||||
batch_first=False, bidirectional=True, is_shared_bidirectional=True)
|
||||
|
||||
|
||||
def static_single(self,inputs, hidden, batch_size):
|
||||
|
||||
outputs = self.cell_rnn(inputs, hidden[0], hidden[1])
|
||||
return torch.split(outputs[-1], split_size_or_sections=batch_size, dim=0)
|
||||
|
||||
def forward(self,inputs,batch_size):
|
||||
## across rows
|
||||
|
||||
row_timestack = torch.cat(torch.unbind(inputs, dim=3),dim=0)
|
||||
|
||||
stateList = self.static_single(torch.stack(torch.unbind(row_timestack,dim=2)),
|
||||
(torch.zeros(1, batch_size * self.nRows, self.nHiddenDims).to(torch.device("cuda")),
|
||||
torch.zeros(1, batch_size * self.nRows, self.nHiddenDims).to(torch.device("cuda"))),batch_size)
|
||||
|
||||
outputs_cols = self.cell_bidirrnn(torch.stack(stateList),
|
||||
torch.zeros(2, batch_size, self.nHiddenDimsBiDir).to(torch.device("cuda")),
|
||||
torch.zeros(2, batch_size, self.nHiddenDimsBiDir).to(torch.device("cuda")))
|
||||
|
||||
|
||||
## across columns
|
||||
col_timestack = torch.cat(torch.unbind(inputs, dim=2),dim=0)
|
||||
|
||||
stateList = self.static_single(torch.stack(torch.unbind(col_timestack,dim=2)),
|
||||
(torch.zeros(1, batch_size * self.nRows, self.nHiddenDims).to(torch.device("cuda")),
|
||||
torch.zeros(1, batch_size * self.nRows, self.nHiddenDims).to(torch.device("cuda"))),batch_size)
|
||||
|
||||
outputs_rows = self.cell_bidirrnn(torch.stack(stateList),
|
||||
torch.zeros(2, batch_size, self.nHiddenDimsBiDir).to(torch.device("cuda")),
|
||||
torch.zeros(2, batch_size, self.nHiddenDimsBiDir).to(torch.device("cuda")))
|
||||
|
||||
|
||||
|
||||
output = torch.cat([outputs_rows[-1],outputs_cols[-1]],1)
|
||||
|
||||
return output
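A minimal usage sketch for the module above; assumptions: a CUDA device is available (the forward pass hard-codes CUDA tensors) and the input layout is (batch, channels, rows, cols):

```python
import torch

# Summarizes each 8x8 patch with 4 input channels into 4 * nHiddenDimsBiDir = 32 features.
rnnpool = RNNPool(nRows=8, nCols=8, nHiddenDims=8, nHiddenDimsBiDir=8, inputDims=4).cuda()

patches = torch.randn(32, 4, 8, 8).cuda()    # (batch, channels, rows, cols)
features = rnnpool(patches, batch_size=32)   # shape (32, 32)
```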
|