restore i3d + itegration test (#562)
* restore i3d + itegration test * pipeline tiemout update * fix * action rec integratino * black
This commit is contained in:
Родитель
cb51b0fb9e
Коммит
af63f8780a
|
@ -54,7 +54,7 @@ The following is a summary of commonly used Computer Vision scenarios that are c
|
|||
| [Detection](scenarios/detection) | Base | Object Detection is a technique that allows you to detect the bounding box of an object within an image. |
|
||||
| [Keypoints](scenarios/keypoints) | Base | Keypoint detection can be used to detect specific points on an object. A pre-trained model is provided to detect body joints for human pose estimation. |
|
||||
| [Segmentation](scenarios/segmentation) | Base | Image Segmentation assigns a category to each pixel in an image. |
|
||||
| [Action recognition](scenarios/action_recognition) | Base | Action recognition to identify in video/webcam footage what actions are performed (e.g. "running", "opening a bottle") and at what respective start/end times.|
|
||||
| [Action recognition](scenarios/action_recognition) | Base | Action recognition to identify in video/webcam footage what actions are performed (e.g. "running", "opening a bottle") and at what respective start/end times. We also implemented the i3d implementation of action recognition that can be found under (contrib)[contrib]. |
|
||||
| [Crowd counting](contrib/crowd_counting) | Contrib | Counting the number of people in low-crowd-density (e.g. less than 10 people) and high-crowd-density (e.g. thousands of people) scenarios.|
|
||||
|
||||
We separate the supported CV scenarios into two locations: (i) **base**: code and notebooks within the "utils_cv" and "scenarios" folders which follow strict coding guidelines, are well tested and maintained; (ii) **contrib**: code and other assets within the "contrib" folder, mainly covering less common CV scenarios using bleeding edge state-of-the-art approaches. Code in "contrib" is not regularly tested or maintained.
|
||||
|
|
|
@ -9,6 +9,7 @@ Each project should live in its own subdirectory ```/contrib/<project>``` and co
|
|||
| Directory | Project description | Build status (optional) |
|
||||
|---|---|---|
|
||||
| [Crowd counting](crowd_counting) | Counting the number of people in low-crowd-density (e.g. less than 10 people) and high-crowd-density (e.g. thousands of people) scenarios. | [![Build Status](https://dev.azure.com/team-sharat/crowd-counting/_apis/build/status/lixzhang.cnt?branchName=lixzhang%2Fsubmodule-rev3)](https://dev.azure.com/team-sharat/crowd-counting/_build/latest?definitionId=49&branchName=lixzhang%2Fsubmodule-rev3)|
|
||||
| [Action Recognition with I3D](action_recognition) | Action recognition to identify video/webcam footage from what actions are performed (e.g. "running", "opening a bottle") and at what respective start/end times. Please note, that we also have a R(2+1)D implementation of action recognition that you can find under [scenarios](../sceanrios).| |
|
||||
|
||||
## Tools
|
||||
| Directory | Project description | Build status (optional) |
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
# Action Recognition
|
||||
|
||||
This directory contains resources for building video-based action recognition systems.
|
||||
|
||||
Action recognition (also known as activity recognition) consists of classifying various actions from a sequence of frames:
|
||||
|
||||
![](./media/action_recognition2.gif "Example of action recognition")
|
||||
|
||||
We implemented two state-of-the-art approaches: (i) [I3D](https://arxiv.org/pdf/1705.07750.pdf) and (ii) [R(2+1)D](https://arxiv.org/abs/1711.11248). This includes example notebooks for e.g. scoring of webcam footage or fine-tuning on the [HMDB-51](http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/) dataset. The latter can be accessed under [scenarios](../scenarios) at the root level.
|
||||
|
||||
We recommend to use the **R(2+1)D** model for its competitive accuracy, fast inference speed, and less dependencies on other packages. For both approaches, using our implementations, we were able to reproduce reported accuracies:
|
||||
|
||||
| Model | Reported in the paper | Our results |
|
||||
| ------- | -------| ------- |
|
||||
| R(2+1)D-34 RGB | 79.6% | 79.8% |
|
||||
| I3D RGB | 74.8% | 73.7% |
|
||||
| I3D Optical flow | 77.1% | 77.5% |
|
||||
| I3D Two-Stream | 80.7% | 81.2% |
|
||||
|
||||
|
||||
## Projects
|
||||
|
||||
| Directory | Description |
|
||||
| -------- | ----------- |
|
||||
| [i3d](i3d) | Scripts for fine-tuning a pre-trained I3D model on HMDB-51
|
||||
dataset. |
|
|
@ -0,0 +1,7 @@
|
|||
__pycache__/
|
||||
models/__pycache__/
|
||||
log/
|
||||
.vscode/
|
||||
checkpoints/
|
||||
pretrained_models/
|
||||
inference/.ipynb_checkpoints/
|
|
@ -0,0 +1,61 @@
|
|||
## Fine-tuning I3D model on HMDB-51
|
||||
|
||||
In this section we provide code for training a Two-Stream Inflated 3D ConvNet (I3D), introduced in \[[1](https://arxiv.org/pdf/1705.07750.pdf)\]. Our implementation uses the Pytorch models (and code) provided in [https://github.com/piergiaj/pytorch-i3d](https://github.com/piergiaj/pytorch-i3d) - which have been pre-trained on the Kinetics Human Action Video dataset - and fine-tunes the models on the HMDB-51 action recognition dataset. The I3D model consists of two "streams" which are independently trained models. One stream takes the RGB image frames from videos as input and the other stream takes pre-computed optical flow as input. At test time, the outputs of each stream model are averaged to make the final prediction. The model results are as follows:
|
||||
|
||||
| Model | Paper top 1 accuracy (average over 3 splits) | Our models top 1 accuracy (split 1 only) |
|
||||
| ------- | -------| ------- |
|
||||
| RGB | 74.8 | 73.7 |
|
||||
| Optical flow | 77.1 | 77.5 |
|
||||
| Two-Stream | 80.7 | 81.2 |
|
||||
|
||||
## Download and pre-process HMDB-51 data
|
||||
|
||||
Download the HMDB-51 video database from [here](http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/). Extract the videos with
|
||||
```
|
||||
mkdir rars && mkdir videos
|
||||
unrar x hmdb51-org.rar rars/
|
||||
for a in $(ls rars); do unrar x "rars/${a}" videos/; done;
|
||||
```
|
||||
|
||||
Use code provided in [https://github.com/yjxiong/temporal-segment-networks](https://github.com/yjxiong/temporal-segment-networks) to preprocess the raw videos into split videos into RGB frames and compute optical flow frames:
|
||||
```
|
||||
git clone https://github.com/yjxiong/temporal-segment-networks
|
||||
cd temporal-segment-networks
|
||||
bash scripts/extract_optical_flow.sh /path/to/hmdb51/videos /path/to/rawframes/output
|
||||
```
|
||||
Edit the _C.DATASET.DIR option in [default.py](default.py) to point towards the rawframes input data directory.
|
||||
|
||||
## Setup environment
|
||||
Setup environment
|
||||
|
||||
```
|
||||
conda env create -f environment.yaml
|
||||
conda activate i3d
|
||||
```
|
||||
|
||||
## Download pretrained models
|
||||
Download pretrained models
|
||||
|
||||
```
|
||||
bash download_models.sh
|
||||
```
|
||||
|
||||
## Fine-tune pretrained models on HMDB-51
|
||||
|
||||
Train RGB model
|
||||
```
|
||||
python train.py --cfg config/train_rgb.yaml
|
||||
```
|
||||
|
||||
Train flow model
|
||||
```
|
||||
python train.py --cfg config/train_flow.yaml
|
||||
```
|
||||
|
||||
Evaluate combined model
|
||||
```
|
||||
python test.py
|
||||
```
|
||||
|
||||
\[1\] J. Carreira and A. Zisserman. Quo vadis, action recognition?
|
||||
a new model and the kinetics dataset. In CVPR, 2017.
|
|
@ -0,0 +1,4 @@
|
|||
MODEL:
|
||||
NAME: "i3d_flow"
|
||||
TRAIN:
|
||||
MODALITY: "flow"
|
|
@ -0,0 +1,4 @@
|
|||
MODEL:
|
||||
NAME: "i3d_rgb"
|
||||
TRAIN:
|
||||
MODALITY: "RGB"
|
|
@ -0,0 +1,244 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Adapted from https://github.com/feiyunzhang/i3d-non-local-pytorch/blob/master/dataset.py
|
||||
|
||||
import torch.utils.data as data
|
||||
import torch
|
||||
|
||||
from PIL import Image
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
from numpy.random import randint
|
||||
from pathlib import Path
|
||||
|
||||
import torchvision
|
||||
from torchvision import datasets, transforms
|
||||
from videotransforms import (
|
||||
GroupRandomCrop, GroupRandomHorizontalFlip,
|
||||
GroupScale, GroupCenterCrop, GroupNormalize, Stack
|
||||
)
|
||||
|
||||
from itertools import cycle
|
||||
|
||||
|
||||
class VideoRecord(object):
|
||||
def __init__(self, row):
|
||||
self._data = row
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return self._data[0]
|
||||
|
||||
@property
|
||||
def num_frames(self):
|
||||
return int(
|
||||
len([x for x in Path(
|
||||
self._data[0]).glob('img_*')])-1)
|
||||
|
||||
@property
|
||||
def label(self):
|
||||
return int(self._data[1])
|
||||
|
||||
|
||||
class I3DDataSet(data.Dataset):
|
||||
def __init__(self, data_root, split=1, sample_frames=64,
|
||||
modality='RGB', transform=lambda x:x,
|
||||
train_mode=True, sample_frames_at_test=False):
|
||||
|
||||
self.data_root = data_root
|
||||
self.split = split
|
||||
self.sample_frames = sample_frames
|
||||
self.modality = modality
|
||||
self.transform = transform
|
||||
self.train_mode = train_mode
|
||||
self.sample_frames_at_test = sample_frames_at_test
|
||||
|
||||
self._parse_split_files()
|
||||
|
||||
|
||||
def _parse_split_files(self):
|
||||
# class labels assigned by sorting the file names in /data/hmdb51_splits directory
|
||||
file_list = sorted(Path('./data/hmdb51_splits').glob('*'+str(self.split)+'.txt'))
|
||||
video_list = []
|
||||
for class_idx, f in enumerate(file_list):
|
||||
class_name = str(f).strip().split('/')[2][:-16]
|
||||
for line in open(f):
|
||||
tokens = line.strip().split(' ')
|
||||
video_path = self.data_root+class_name+'/'+tokens[0][:-4]
|
||||
record = (video_path, class_idx)
|
||||
# 1 indicates video should be in training set
|
||||
if self.train_mode & (tokens[-1] == '1'):
|
||||
video_list.append(VideoRecord(record))
|
||||
# 2 indicates video should be in test set
|
||||
elif (self.train_mode == False) & (tokens[-1] == '2'):
|
||||
video_list.append(VideoRecord(record))
|
||||
|
||||
self.video_list = video_list
|
||||
|
||||
|
||||
def _load_image(self, directory, idx):
|
||||
if self.modality == 'RGB':
|
||||
img_path = os.path.join(directory, 'img_{:05}.jpg'.format(idx))
|
||||
try:
|
||||
img = Image.open(img_path).convert('RGB')
|
||||
except:
|
||||
print("Couldn't load image:{}".format(img_path))
|
||||
return None
|
||||
return img
|
||||
else:
|
||||
try:
|
||||
img_path = os.path.join(directory, 'flow_x_{:05}.jpg'.format(idx))
|
||||
x_img = Image.open(img_path).convert('L')
|
||||
except:
|
||||
print("Couldn't load image:{}".format(img_path))
|
||||
return None
|
||||
try:
|
||||
img_path = os.path.join(directory, 'flow_y_{:05}.jpg'.format(idx))
|
||||
y_img = Image.open(img_path).convert('L')
|
||||
except:
|
||||
print("Couldn't load image:{}".format(img_path))
|
||||
return None
|
||||
# Combine flow images into single PIL image
|
||||
x_img = np.array(x_img, dtype=np.float32)
|
||||
y_img = np.array(y_img, dtype=np.float32)
|
||||
img = np.asarray([x_img, y_img]).transpose([1, 2, 0])
|
||||
img = Image.fromarray(img.astype('uint8'))
|
||||
return img
|
||||
|
||||
|
||||
def _sample_indices(self, record):
|
||||
if record.num_frames > self.sample_frames:
|
||||
start_pos = randint(record.num_frames - self.sample_frames + 1)
|
||||
indices = range(start_pos, start_pos + self.sample_frames, 1)
|
||||
else:
|
||||
indices = [x for x in range(record.num_frames)]
|
||||
if len(indices) < self.sample_frames:
|
||||
self._loop_indices(indices)
|
||||
return indices
|
||||
|
||||
|
||||
def _loop_indices(self, indices):
|
||||
indices_cycle = cycle(indices)
|
||||
while len(indices) < self.sample_frames:
|
||||
indices.append(next(indices_cycle))
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
record = self.video_list[index]
|
||||
# Sample frames from the the video for training, or if sampling
|
||||
# turned on at test time
|
||||
if self.train_mode or self.sample_frames_at_test:
|
||||
segment_indices = self._sample_indices(record)
|
||||
else:
|
||||
segment_indices = [i for i in range(record.num_frames)]
|
||||
# Image files are 1-indexed
|
||||
segment_indices = [i+1 for i in segment_indices]
|
||||
# Get video frame images
|
||||
images = []
|
||||
for i in segment_indices:
|
||||
seg_img = self._load_image(record.path, i)
|
||||
if seg_img is None:
|
||||
raise ValueError("Couldn't load", record.path, i)
|
||||
images.append(seg_img)
|
||||
# Apply transformations
|
||||
transformed_images = self.transform(images)
|
||||
|
||||
return transformed_images, record.label
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.video_list)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
input_size = 224
|
||||
resize_small_edge = 256
|
||||
|
||||
train_rgb = I3DDataSet(
|
||||
data_root='/datadir/rawframes/',
|
||||
split=1,
|
||||
sample_frames = 64,
|
||||
modality='RGB',
|
||||
train_mode=True,
|
||||
sample_frames_at_test=False,
|
||||
transform=torchvision.transforms.Compose([
|
||||
GroupScale(resize_small_edge),
|
||||
GroupRandomCrop(input_size),
|
||||
GroupRandomHorizontalFlip(),
|
||||
GroupNormalize(modality="RGB"),
|
||||
Stack(),
|
||||
])
|
||||
)
|
||||
item = train_rgb.__getitem__(10)
|
||||
print("train_rgb:")
|
||||
print(item[0].size())
|
||||
print("max=", item[0].max())
|
||||
print("min=", item[0].min())
|
||||
print("label=",item[1])
|
||||
|
||||
val_rgb = I3DDataSet(
|
||||
data_root='/datadir/rawframes/',
|
||||
split=1,
|
||||
sample_frames = 64,
|
||||
modality='RGB',
|
||||
train_mode=False,
|
||||
sample_frames_at_test=False,
|
||||
transform=torchvision.transforms.Compose([
|
||||
GroupScale(resize_small_edge),
|
||||
GroupCenterCrop(input_size),
|
||||
GroupNormalize(modality="RGB"),
|
||||
Stack(),
|
||||
])
|
||||
)
|
||||
item = val_rgb.__getitem__(10)
|
||||
print("val_rgb:")
|
||||
print(item[0].size())
|
||||
print("max=", item[0].max())
|
||||
print("min=", item[0].min())
|
||||
print("label=",item[1])
|
||||
|
||||
train_flow = I3DDataSet(
|
||||
data_root='/datadir/rawframes/',
|
||||
split=1,
|
||||
sample_frames = 64,
|
||||
modality='flow',
|
||||
train_mode=True,
|
||||
sample_frames_at_test=False,
|
||||
transform=torchvision.transforms.Compose([
|
||||
GroupScale(resize_small_edge),
|
||||
GroupRandomCrop(input_size),
|
||||
GroupRandomHorizontalFlip(),
|
||||
GroupNormalize(modality="flow"),
|
||||
Stack(),
|
||||
])
|
||||
)
|
||||
item = train_flow.__getitem__(100)
|
||||
print("train_flow:")
|
||||
print(item[0].size())
|
||||
print("max=", item[0].max())
|
||||
print("min=", item[0].min())
|
||||
print("label=",item[1])
|
||||
|
||||
val_flow = I3DDataSet(
|
||||
data_root='/datadir/rawframes/',
|
||||
split=1,
|
||||
sample_frames = 64,
|
||||
modality='flow',
|
||||
train_mode=False,
|
||||
sample_frames_at_test=False,
|
||||
transform=torchvision.transforms.Compose([
|
||||
GroupScale(resize_small_edge),
|
||||
GroupCenterCrop(input_size),
|
||||
GroupNormalize(modality="flow"),
|
||||
Stack(),
|
||||
])
|
||||
)
|
||||
item = val_flow.__getitem__(100)
|
||||
print("val_flow:")
|
||||
print(item[0].size())
|
||||
print("max=", item[0].max())
|
||||
print("min=", item[0].min())
|
||||
print("label=",item[1])
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
|
||||
_C = CN()
|
||||
|
||||
_C.LOG_DIR = "log"
|
||||
_C.WORKERS = 16
|
||||
_C.PIN_MEMORY = True
|
||||
_C.SEED = 42
|
||||
|
||||
# Cudnn related params
|
||||
_C.CUDNN = CN()
|
||||
_C.CUDNN.BENCHMARK = True
|
||||
|
||||
# Dataset
|
||||
_C.DATASET = CN()
|
||||
_C.DATASET.SPLIT = 1
|
||||
_C.DATASET.DIR = "/datadir/rawframes/"
|
||||
_C.DATASET.NUM_CLASSES = 51
|
||||
|
||||
# NETWORK
|
||||
_C.MODEL = CN()
|
||||
_C.MODEL.NAME = "i3d_flow"
|
||||
_C.MODEL.PRETRAINED_RGB = "pretrained_models/rgb_imagenet_kinetics.pt"
|
||||
_C.MODEL.PRETRAINED_FLOW = "pretrained_models/flow_imagenet_kinetics.pt"
|
||||
_C.MODEL.CHECKPOINT_DIR = "checkpoints"
|
||||
|
||||
# Train
|
||||
_C.TRAIN = CN()
|
||||
_C.TRAIN.PRINT_FREQ = 50
|
||||
_C.TRAIN.INPUT_SIZE = 224
|
||||
_C.TRAIN.RESIZE_MIN = 256
|
||||
_C.TRAIN.SAMPLE_FRAMES = 64
|
||||
_C.TRAIN.MODALITY = "flow"
|
||||
_C.TRAIN.BATCH_SIZE = 24
|
||||
_C.TRAIN.GRAD_ACCUM_STEPS = 4
|
||||
_C.TRAIN.MAX_EPOCHS = 50
|
||||
|
||||
# Test
|
||||
_C.TEST = CN()
|
||||
_C.TEST.EVAL_FREQ = 5
|
||||
_C.TEST.PRINT_FREQ = 250
|
||||
_C.TEST.BATCH_SIZE = 1
|
||||
_C.TEST.MODALITY = "combined"
|
||||
_C.TEST.MODEL_RGB = "pretrained_models/rgb_hmdb_split1.pt"
|
||||
_C.TEST.MODEL_FLOW = "pretrained_models/flow_hmdb_split1.pt"
|
||||
|
||||
def update_config(cfg, options=None, config_file=None):
|
||||
cfg.defrost()
|
||||
|
||||
if config_file:
|
||||
cfg.merge_from_file(config_file)
|
||||
|
||||
if options:
|
||||
cfg.merge_from_list(options)
|
||||
|
||||
cfg.freeze()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
with open(sys.argv[1], "w") as f:
|
||||
print(_C, file=f)
|
|
@ -0,0 +1,10 @@
|
|||
#!/usr/bin/env bash
|
||||
wget https://har.blob.core.windows.net/i3dmodels/flow_hmdb_split1.pt
|
||||
wget https://har.blob.core.windows.net/i3dmodels/rgb_hmdb_split1.pt
|
||||
wget https://har.blob.core.windows.net/i3dmodels/flow_imagenet_kinetics.pt
|
||||
wget https://har.blob.core.windows.net/i3dmodels/rgb_imagenet_kinetics.pt
|
||||
|
||||
mv flow_hmdb_split1.pt pretrained_models/flow_hmdb_split1.pt
|
||||
mv rgb_hmdb_split1.pt pretrained_models/rgb_hmdb_split1.pt
|
||||
mv flow_imagenet_kinetics.pt pretrained_models/flow_imagenet_kinetics.pt
|
||||
mv rgb_imagenet_kinetics.pt pretrained_models/rgb_imagenet_kinetics.pt
|
|
@ -0,0 +1,20 @@
|
|||
name: i3d
|
||||
dependencies:
|
||||
- python=3.6.2
|
||||
- pandas
|
||||
- numpy
|
||||
- ipykernel
|
||||
- matplotlib
|
||||
- pip:
|
||||
- torch==1.2.0
|
||||
- torchvision
|
||||
- pillow
|
||||
- fire
|
||||
- tensorboardX
|
||||
- tensorboard
|
||||
- yacs
|
||||
- opencv-contrib-python-headless
|
||||
|
||||
channels:
|
||||
- conda-forge
|
||||
- anaconda
|
|
@ -0,0 +1,97 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
from videotransforms import (
|
||||
GroupScale, GroupCenterCrop, GroupNormalize, Stack
|
||||
)
|
||||
from models.pytorch_i3d import InceptionI3d
|
||||
from dataset import I3DDataSet
|
||||
from test import load_model
|
||||
|
||||
|
||||
def load_image(frame_file):
|
||||
try:
|
||||
img = Image.open(frame_file).convert('RGB')
|
||||
return img
|
||||
except:
|
||||
print("Couldn't load image:{}".format(frame_file))
|
||||
return None
|
||||
|
||||
|
||||
def load_frames(frame_paths):
|
||||
frame_list = []
|
||||
for frame in frame_paths:
|
||||
frame_list.append(load_image(frame))
|
||||
return frame_list
|
||||
|
||||
|
||||
def construct_input(frame_list):
|
||||
|
||||
transform = torchvision.transforms.Compose([
|
||||
GroupScale(config.TRAIN.RESIZE_MIN),
|
||||
GroupCenterCrop(config.TRAIN.INPUT_SIZE),
|
||||
GroupNormalize(modality="RGB"),
|
||||
Stack(),
|
||||
])
|
||||
|
||||
process_data = transform(frame_list)
|
||||
return process_data.unsqueeze(0)
|
||||
|
||||
|
||||
def predict_input(model, input):
|
||||
input = input.cuda(non_blocking=True)
|
||||
output = model(input)
|
||||
output = torch.mean(output, dim=2)
|
||||
return output
|
||||
|
||||
|
||||
def predict_over_video(video_frame_list, window_width=9, stride=1):
|
||||
|
||||
if window_width < 9:
|
||||
raise ValueError("window_width must be 9 or greater")
|
||||
|
||||
print("Loading model...")
|
||||
|
||||
model = load_model(
|
||||
modality="RGB",
|
||||
state_dict_file="pretrained_chkpt/rgb_hmdb_split1.pt"
|
||||
)
|
||||
|
||||
model.eval()
|
||||
|
||||
print("Predicting actions over {0} frames".format(len(video_frame_list)))
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
window_count = 0
|
||||
|
||||
for i in range(stride+window_width-1, len(video_frame_list), stride):
|
||||
window_frame_list = [video_frame_list[j] for j in range(i-window_width, i)]
|
||||
frames = load_frames(window_frame_list)
|
||||
batch = construct_input(frames)
|
||||
window_predictions = predict_input(model, batch)
|
||||
window_proba = F.softmax(window_predictions, dim=1)
|
||||
window_top_pred = window_proba.max(1)
|
||||
print(("Window:{0} Class pred:{1} Class proba:{2}".format(
|
||||
window_count,
|
||||
window_top_pred.indices.cpu().numpy()[0],
|
||||
window_top_pred.values.cpu().numpy()[0])
|
||||
))
|
||||
window_count += 1
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Provide list of filepaths to video frames
|
||||
frame_paths = []
|
||||
|
||||
predict_over_video(frame_list, window_width=64, stride=32)
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# From https://github.com/feiyunzhang/i3d-non-local-pytorch/blob/master/main.py
|
||||
|
||||
import torch
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
||||
with torch.no_grad():
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
|
@ -0,0 +1,338 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
import numpy as np
|
||||
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class MaxPool3dSamePadding(nn.MaxPool3d):
|
||||
|
||||
def compute_pad(self, dim, s):
|
||||
if s % self.stride[dim] == 0:
|
||||
return max(self.kernel_size[dim] - self.stride[dim], 0)
|
||||
else:
|
||||
return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)
|
||||
|
||||
def forward(self, x):
|
||||
# compute 'same' padding
|
||||
(batch, channel, t, h, w) = x.size()
|
||||
#print t,h,w
|
||||
out_t = np.ceil(float(t) / float(self.stride[0]))
|
||||
out_h = np.ceil(float(h) / float(self.stride[1]))
|
||||
out_w = np.ceil(float(w) / float(self.stride[2]))
|
||||
#print out_t, out_h, out_w
|
||||
pad_t = self.compute_pad(0, t)
|
||||
pad_h = self.compute_pad(1, h)
|
||||
pad_w = self.compute_pad(2, w)
|
||||
#print pad_t, pad_h, pad_w
|
||||
|
||||
pad_t_f = pad_t // 2
|
||||
pad_t_b = pad_t - pad_t_f
|
||||
pad_h_f = pad_h // 2
|
||||
pad_h_b = pad_h - pad_h_f
|
||||
pad_w_f = pad_w // 2
|
||||
pad_w_b = pad_w - pad_w_f
|
||||
|
||||
pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
|
||||
#print x.size()
|
||||
#print pad
|
||||
x = F.pad(x, pad)
|
||||
return super(MaxPool3dSamePadding, self).forward(x)
|
||||
|
||||
|
||||
class Unit3D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels,
|
||||
output_channels,
|
||||
kernel_shape=(1, 1, 1),
|
||||
stride=(1, 1, 1),
|
||||
padding=0,
|
||||
activation_fn=F.relu,
|
||||
use_batch_norm=True,
|
||||
use_bias=False,
|
||||
name='unit_3d'):
|
||||
|
||||
"""Initializes Unit3D module."""
|
||||
super(Unit3D, self).__init__()
|
||||
|
||||
self._output_channels = output_channels
|
||||
self._kernel_shape = kernel_shape
|
||||
self._stride = stride
|
||||
self._use_batch_norm = use_batch_norm
|
||||
self._activation_fn = activation_fn
|
||||
self._use_bias = use_bias
|
||||
self.name = name
|
||||
self.padding = padding
|
||||
|
||||
self.conv3d = nn.Conv3d(in_channels=in_channels,
|
||||
out_channels=self._output_channels,
|
||||
kernel_size=self._kernel_shape,
|
||||
stride=self._stride,
|
||||
padding=0, # we always want padding to be 0 here. We will dynamically pad based on input size in forward function
|
||||
bias=self._use_bias)
|
||||
|
||||
if self._use_batch_norm:
|
||||
self.bn = nn.BatchNorm3d(self._output_channels, eps=0.001, momentum=0.01)
|
||||
|
||||
def compute_pad(self, dim, s):
|
||||
if s % self._stride[dim] == 0:
|
||||
return max(self._kernel_shape[dim] - self._stride[dim], 0)
|
||||
else:
|
||||
return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
# compute 'same' padding
|
||||
(batch, channel, t, h, w) = x.size()
|
||||
#print t,h,w
|
||||
out_t = np.ceil(float(t) / float(self._stride[0]))
|
||||
out_h = np.ceil(float(h) / float(self._stride[1]))
|
||||
out_w = np.ceil(float(w) / float(self._stride[2]))
|
||||
#print out_t, out_h, out_w
|
||||
pad_t = self.compute_pad(0, t)
|
||||
pad_h = self.compute_pad(1, h)
|
||||
pad_w = self.compute_pad(2, w)
|
||||
#print pad_t, pad_h, pad_w
|
||||
|
||||
pad_t_f = pad_t // 2
|
||||
pad_t_b = pad_t - pad_t_f
|
||||
pad_h_f = pad_h // 2
|
||||
pad_h_b = pad_h - pad_h_f
|
||||
pad_w_f = pad_w // 2
|
||||
pad_w_b = pad_w - pad_w_f
|
||||
|
||||
pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
|
||||
#print x.size()
|
||||
#print pad
|
||||
x = F.pad(x, pad)
|
||||
#print x.size()
|
||||
|
||||
x = self.conv3d(x)
|
||||
if self._use_batch_norm:
|
||||
x = self.bn(x)
|
||||
if self._activation_fn is not None:
|
||||
x = self._activation_fn(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class InceptionModule(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, name):
|
||||
super(InceptionModule, self).__init__()
|
||||
|
||||
self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,
|
||||
name=name+'/Branch_0/Conv3d_0a_1x1')
|
||||
self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,
|
||||
name=name+'/Branch_1/Conv3d_0a_1x1')
|
||||
self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3],
|
||||
name=name+'/Branch_1/Conv3d_0b_3x3')
|
||||
self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,
|
||||
name=name+'/Branch_2/Conv3d_0a_1x1')
|
||||
self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3],
|
||||
name=name+'/Branch_2/Conv3d_0b_3x3')
|
||||
self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
|
||||
stride=(1, 1, 1), padding=0)
|
||||
self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,
|
||||
name=name+'/Branch_3/Conv3d_0b_1x1')
|
||||
self.name = name
|
||||
|
||||
def forward(self, x):
|
||||
b0 = self.b0(x)
|
||||
b1 = self.b1b(self.b1a(x))
|
||||
b2 = self.b2b(self.b2a(x))
|
||||
b3 = self.b3b(self.b3a(x))
|
||||
return torch.cat([b0,b1,b2,b3], dim=1)
|
||||
|
||||
|
||||
class InceptionI3d(nn.Module):
|
||||
"""Inception-v1 I3D architecture.
|
||||
The model is introduced in:
|
||||
Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
|
||||
Joao Carreira, Andrew Zisserman
|
||||
https://arxiv.org/pdf/1705.07750v1.pdf.
|
||||
See also the Inception architecture, introduced in:
|
||||
Going deeper with convolutions
|
||||
Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
|
||||
Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
|
||||
http://arxiv.org/pdf/1409.4842v1.pdf.
|
||||
"""
|
||||
|
||||
# Endpoints of the model in order. During construction, all the endpoints up
|
||||
# to a designated `final_endpoint` are returned in a dictionary as the
|
||||
# second return value.
|
||||
VALID_ENDPOINTS = (
|
||||
'Conv3d_1a_7x7',
|
||||
'MaxPool3d_2a_3x3',
|
||||
'Conv3d_2b_1x1',
|
||||
'Conv3d_2c_3x3',
|
||||
'MaxPool3d_3a_3x3',
|
||||
'Mixed_3b',
|
||||
'Mixed_3c',
|
||||
'MaxPool3d_4a_3x3',
|
||||
'Mixed_4b',
|
||||
'Mixed_4c',
|
||||
'Mixed_4d',
|
||||
'Mixed_4e',
|
||||
'Mixed_4f',
|
||||
'MaxPool3d_5a_2x2',
|
||||
'Mixed_5b',
|
||||
'Mixed_5c',
|
||||
'Logits',
|
||||
'Predictions',
|
||||
)
|
||||
|
||||
def __init__(self, num_classes=400, spatial_squeeze=True,
|
||||
final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5):
|
||||
"""Initializes I3D model instance.
|
||||
Args:
|
||||
num_classes: The number of outputs in the logit layer (default 400, which
|
||||
matches the Kinetics dataset).
|
||||
spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
|
||||
before returning (default True).
|
||||
final_endpoint: The model contains many possible endpoints.
|
||||
`final_endpoint` specifies the last endpoint for the model to be built
|
||||
up to. In addition to the output at `final_endpoint`, all the outputs
|
||||
at endpoints up to `final_endpoint` will also be returned, in a
|
||||
dictionary. `final_endpoint` must be one of
|
||||
InceptionI3d.VALID_ENDPOINTS (default 'Logits').
|
||||
name: A string (optional). The name of this module.
|
||||
Raises:
|
||||
ValueError: if `final_endpoint` is not recognized.
|
||||
"""
|
||||
|
||||
if final_endpoint not in self.VALID_ENDPOINTS:
|
||||
raise ValueError('Unknown final endpoint %s' % final_endpoint)
|
||||
|
||||
super(InceptionI3d, self).__init__()
|
||||
self._num_classes = num_classes
|
||||
self._spatial_squeeze = spatial_squeeze
|
||||
self._final_endpoint = final_endpoint
|
||||
self.logits = None
|
||||
|
||||
if self._final_endpoint not in self.VALID_ENDPOINTS:
|
||||
raise ValueError('Unknown final endpoint %s' % self._final_endpoint)
|
||||
|
||||
self.end_points = {}
|
||||
end_point = 'Conv3d_1a_7x7'
|
||||
self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],
|
||||
stride=(2, 2, 2), padding=(3,3,3), name=name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'MaxPool3d_2a_3x3'
|
||||
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
|
||||
padding=0)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Conv3d_2b_1x1'
|
||||
self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
|
||||
name=name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Conv3d_2c_3x3'
|
||||
self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,
|
||||
name=name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'MaxPool3d_3a_3x3'
|
||||
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
|
||||
padding=0)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_3b'
|
||||
self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_3c'
|
||||
self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'MaxPool3d_4a_3x3'
|
||||
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),
|
||||
padding=0)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_4b'
|
||||
self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_4c'
|
||||
self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_4d'
|
||||
self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_4e'
|
||||
self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_4f'
|
||||
self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'MaxPool3d_5a_2x2'
|
||||
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),
|
||||
padding=0)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_5b'
|
||||
self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_5c'
|
||||
self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Logits'
|
||||
self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7],
|
||||
stride=(1, 1, 1))
|
||||
self.dropout = nn.Dropout(dropout_keep_prob)
|
||||
self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
|
||||
kernel_shape=[1, 1, 1],
|
||||
padding=0,
|
||||
activation_fn=None,
|
||||
use_batch_norm=False,
|
||||
use_bias=True,
|
||||
name='logits')
|
||||
|
||||
self.build()
|
||||
|
||||
|
||||
def replace_logits(self, num_classes):
|
||||
self._num_classes = num_classes
|
||||
self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
|
||||
kernel_shape=[1, 1, 1],
|
||||
padding=0,
|
||||
activation_fn=None,
|
||||
use_batch_norm=False,
|
||||
use_bias=True,
|
||||
name='logits')
|
||||
|
||||
|
||||
def build(self):
|
||||
for k in self.end_points.keys():
|
||||
self.add_module(k, self.end_points[k])
|
||||
|
||||
def forward(self, x):
|
||||
for end_point in self.VALID_ENDPOINTS:
|
||||
if end_point in self.end_points:
|
||||
x = self._modules[end_point](x) # use _modules to work with dataparallel
|
||||
|
||||
x = self.logits(self.dropout(self.avg_pool(x)))
|
||||
if self._spatial_squeeze:
|
||||
logits = x.squeeze(3).squeeze(3)
|
||||
# logits is batch X time X classes, which is what we want to work with
|
||||
return logits
|
||||
|
||||
|
||||
def extract_features(self, x):
|
||||
for end_point in self.VALID_ENDPOINTS:
|
||||
if end_point in self.end_points:
|
||||
x = self._modules[end_point](x)
|
||||
return self.avg_pool(x)
|
|
@ -0,0 +1,167 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
import fire
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
from videotransforms import (
|
||||
GroupScale, GroupCenterCrop, GroupNormalize, Stack
|
||||
)
|
||||
|
||||
from models.pytorch_i3d import InceptionI3d
|
||||
|
||||
from metrics import accuracy, AverageMeter
|
||||
|
||||
from dataset import I3DDataSet
|
||||
from default import _C as config
|
||||
from default import update_config
|
||||
|
||||
# to work with vscode debugger https://github.com/joblib/joblib/issues/864
|
||||
import multiprocessing
|
||||
multiprocessing.set_start_method('spawn', True)
|
||||
|
||||
|
||||
def load_model(modality, state_dict_file):
|
||||
|
||||
channels = 3 if modality == "RGB" else 2
|
||||
model = InceptionI3d(config.DATASET.NUM_CLASSES, in_channels=channels)
|
||||
state_dict = torch.load(state_dict_file)
|
||||
model.load_state_dict(state_dict)
|
||||
model = model.cuda()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def test(model, test_loader, modality):
|
||||
|
||||
model.eval()
|
||||
|
||||
target_list = []
|
||||
predictions_list = []
|
||||
with torch.no_grad():
|
||||
end = time.time()
|
||||
for step, (input, target) in enumerate(test_loader):
|
||||
target_list.append(target)
|
||||
input = input.cuda(non_blocking=True)
|
||||
|
||||
# compute output
|
||||
output = model(input)
|
||||
output = torch.mean(output, dim=2)
|
||||
predictions_list.append(output)
|
||||
|
||||
if step % config.TEST.PRINT_FREQ == 0:
|
||||
print(('Step: [{0}/{1}]'.format(step, len(test_loader))))
|
||||
|
||||
targets = torch.cat(target_list)
|
||||
predictions = torch.cat(predictions_list)
|
||||
return targets, predictions
|
||||
|
||||
|
||||
def run(*options, cfg=None):
|
||||
|
||||
update_config(config, options=options, config_file=cfg)
|
||||
|
||||
torch.backends.cudnn.benchmark = config.CUDNN.BENCHMARK
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(config.SEED)
|
||||
np.random.seed(seed=config.SEED)
|
||||
|
||||
# Setup Augmentation/Transformation pipeline
|
||||
input_size = config.TRAIN.INPUT_SIZE
|
||||
resize_range_min = config.TRAIN.RESIZE_MIN
|
||||
|
||||
# Data-parallel
|
||||
devices_lst = list(range(torch.cuda.device_count()))
|
||||
print("Devices {}".format(devices_lst))
|
||||
|
||||
if (config.TEST.MODALITY == "RGB") or (config.TEST.MODALITY == "combined"):
|
||||
|
||||
rgb_loader = torch.utils.data.DataLoader(
|
||||
I3DDataSet(
|
||||
data_root=config.DATASET.DIR,
|
||||
split=config.DATASET.SPLIT,
|
||||
modality="RGB",
|
||||
train_mode=False,
|
||||
sample_frames_at_test=False,
|
||||
transform=torchvision.transforms.Compose([
|
||||
GroupScale(config.TRAIN.RESIZE_MIN),
|
||||
GroupCenterCrop(config.TRAIN.INPUT_SIZE),
|
||||
GroupNormalize(modality="RGB"),
|
||||
Stack(),
|
||||
])
|
||||
),
|
||||
batch_size=config.TEST.BATCH_SIZE,
|
||||
shuffle=False,
|
||||
num_workers=config.WORKERS,
|
||||
pin_memory=config.PIN_MEMORY
|
||||
)
|
||||
|
||||
rgb_model_file = config.TEST.MODEL_RGB
|
||||
if not os.path.exists(rgb_model_file):
|
||||
raise FileNotFoundError(rgb_model_file, " does not exist")
|
||||
rgb_model = load_model(modality="RGB", state_dict_file=rgb_model_file)
|
||||
|
||||
print("scoring with rgb model")
|
||||
targets, rgb_predictions = test(rgb_model, rgb_loader, "RGB")
|
||||
|
||||
del rgb_model
|
||||
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
rgb_top1_accuracy = accuracy(rgb_predictions, targets, topk=(1, ))
|
||||
print("rgb top1 accuracy: ", rgb_top1_accuracy[0].cpu().numpy().tolist())
|
||||
|
||||
if (config.TEST.MODALITY == "flow") or (config.TEST.MODALITY == "combined"):
|
||||
|
||||
flow_loader = torch.utils.data.DataLoader(
|
||||
I3DDataSet(
|
||||
data_root=config.DATASET.DIR,
|
||||
split=config.DATASET.SPLIT,
|
||||
modality="flow",
|
||||
train_mode=False,
|
||||
sample_frames_at_test=False,
|
||||
transform=torchvision.transforms.Compose([
|
||||
GroupScale(config.TRAIN.RESIZE_MIN),
|
||||
GroupCenterCrop(config.TRAIN.INPUT_SIZE),
|
||||
GroupNormalize(modality="flow"),
|
||||
Stack(),
|
||||
])
|
||||
),
|
||||
batch_size=config.TEST.BATCH_SIZE,
|
||||
shuffle=False,
|
||||
num_workers=config.WORKERS,
|
||||
pin_memory=config.PIN_MEMORY
|
||||
)
|
||||
|
||||
flow_model_file = config.TEST.MODEL_FLOW
|
||||
if not os.path.exists(flow_model_file):
|
||||
raise FileNotFoundError(flow_model_file, " does not exist")
|
||||
flow_model = load_model(modality="flow", state_dict_file=flow_model_file)
|
||||
|
||||
print("scoring with flow model")
|
||||
targets, flow_predictions = test(flow_model, flow_loader, "flow")
|
||||
|
||||
del flow_model
|
||||
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
flow_top1_accuracy = accuracy(flow_predictions, targets, topk=(1, ))
|
||||
print("flow top1 accuracy: ", flow_top1_accuracy[0].cpu().numpy().tolist())
|
||||
|
||||
if config.TEST.MODALITY == "combined":
|
||||
predictions = torch.stack([rgb_predictions, flow_predictions])
|
||||
predictions_mean = torch.mean(predictions, dim=0)
|
||||
top1accuracy = accuracy(predictions_mean, targets, topk=(1, ))
|
||||
print("combined top1 accuracy: ", top1accuracy[0].cpu().numpy().tolist())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(run)
|
|
@ -0,0 +1,278 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
import numpy as np
|
||||
import fire
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.optim import lr_scheduler
|
||||
from torch.autograd import Variable
|
||||
import torchvision
|
||||
from torchvision import datasets, transforms
|
||||
from tensorboardX import SummaryWriter
|
||||
from default import _C as config
|
||||
from default import update_config
|
||||
|
||||
from videotransforms import (
|
||||
GroupRandomCrop, GroupRandomHorizontalFlip,
|
||||
GroupScale, GroupCenterCrop, GroupNormalize, Stack
|
||||
)
|
||||
from models.pytorch_i3d import InceptionI3d
|
||||
from metrics import accuracy, AverageMeter
|
||||
from dataset import I3DDataSet
|
||||
|
||||
|
||||
# to work with vscode debugger https://github.com/joblib/joblib/issues/864
|
||||
import multiprocessing
|
||||
multiprocessing.set_start_method('spawn', True)
|
||||
|
||||
|
||||
def train(train_loader, model, criterion, optimizer, epoch, writer=None):
|
||||
batch_time = AverageMeter()
|
||||
data_time = AverageMeter()
|
||||
losses = AverageMeter()
|
||||
top1 = AverageMeter()
|
||||
top5 = AverageMeter()
|
||||
|
||||
# switch to train mode
|
||||
model.train()
|
||||
|
||||
end = time.time()
|
||||
for step, (input, target) in enumerate(train_loader):
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
input = input.cuda(non_blocking=True)
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
# compute output
|
||||
output = model(input)
|
||||
output = torch.mean(output, dim=2)
|
||||
loss = criterion(output, target)
|
||||
|
||||
# measure accuracy and record loss
|
||||
prec1, prec5 = accuracy(output, target, topk=(1,5))
|
||||
losses.update(loss.item(), input.size(0))
|
||||
top1.update(prec1[0], input.size(0))
|
||||
top5.update(prec5[0], input.size(0))
|
||||
|
||||
loss = loss / config.TRAIN.GRAD_ACCUM_STEPS
|
||||
|
||||
loss.backward()
|
||||
|
||||
if step % config.TRAIN.GRAD_ACCUM_STEPS == 0:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if step % config.TRAIN.PRINT_FREQ == 0:
|
||||
print(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
|
||||
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
|
||||
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
|
||||
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
|
||||
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
|
||||
epoch, step, len(train_loader), batch_time=batch_time,
|
||||
data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'])))
|
||||
|
||||
if writer:
|
||||
writer.add_scalar('train/loss', losses.avg, epoch+1)
|
||||
writer.add_scalar('train/top1', top1.avg, epoch+1)
|
||||
writer.add_scalar('train/top5', top5.avg, epoch+1)
|
||||
|
||||
|
||||
def validate(val_loader, model, criterion, epoch, writer=None):
|
||||
batch_time = AverageMeter()
|
||||
losses = AverageMeter()
|
||||
top1 = AverageMeter()
|
||||
top5 = AverageMeter()
|
||||
|
||||
# switch to evaluate mode
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
end = time.time()
|
||||
for step, (input, target) in enumerate(val_loader):
|
||||
input = input.cuda(non_blocking=True)
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
# compute output
|
||||
output = model(input)
|
||||
output = torch.mean(output, dim=2)
|
||||
loss = criterion(output, target)
|
||||
|
||||
# measure accuracy and record loss
|
||||
prec1, prec5 = accuracy(output, target, topk=(1,5))
|
||||
|
||||
losses.update(loss.item(), input.size(0))
|
||||
top1.update(prec1[0], input.size(0))
|
||||
top5.update(prec5[0], input.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if step % config.TEST.PRINT_FREQ == 0:
|
||||
print(('Test: [{0}/{1}]\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
|
||||
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
|
||||
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
|
||||
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
|
||||
step, len(val_loader), batch_time=batch_time, loss=losses,
|
||||
top1=top1, top5=top5)))
|
||||
|
||||
print(('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
|
||||
.format(top1=top1, top5=top5, loss=losses)))
|
||||
|
||||
if writer:
|
||||
writer.add_scalar('val/loss', losses.avg, epoch+1)
|
||||
writer.add_scalar('val/top1', top1.avg, epoch+1)
|
||||
writer.add_scalar('val/top5', top5.avg, epoch+1)
|
||||
|
||||
return losses.avg
|
||||
|
||||
|
||||
def run(*options, cfg=None):
|
||||
"""Run training and validation of model
|
||||
|
||||
Notes:
|
||||
Options can be passed in via the options argument and loaded from the cfg file
|
||||
Options loaded from default.py will be overridden by options loaded from cfg file
|
||||
Options passed in through options argument will override option loaded from cfg file
|
||||
|
||||
Args:
|
||||
*options (str,int ,optional): Options used to overide what is loaded from the config.
|
||||
To see what options are available consult default.py
|
||||
cfg (str, optional): Location of config file to load. Defaults to None.
|
||||
"""
|
||||
update_config(config, options=options, config_file=cfg)
|
||||
|
||||
print("Training ", config.TRAIN.MODALITY, " model.")
|
||||
print("Batch size:", config.TRAIN.BATCH_SIZE, " Gradient accumulation steps:", config.TRAIN.GRAD_ACCUM_STEPS)
|
||||
|
||||
torch.backends.cudnn.benchmark = config.CUDNN.BENCHMARK
|
||||
|
||||
torch.manual_seed(config.SEED)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(config.SEED)
|
||||
np.random.seed(seed=config.SEED)
|
||||
|
||||
# Log to tensorboard
|
||||
writer = SummaryWriter(log_dir=config.LOG_DIR)
|
||||
|
||||
# Setup dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
I3DDataSet(
|
||||
data_root=config.DATASET.DIR,
|
||||
split=config.DATASET.SPLIT,
|
||||
sample_frames=config.TRAIN.SAMPLE_FRAMES,
|
||||
modality=config.TRAIN.MODALITY,
|
||||
transform=torchvision.transforms.Compose([
|
||||
GroupScale(config.TRAIN.RESIZE_MIN),
|
||||
GroupRandomCrop(config.TRAIN.INPUT_SIZE),
|
||||
GroupRandomHorizontalFlip(),
|
||||
GroupNormalize(modality=config.TRAIN.MODALITY),
|
||||
Stack(),
|
||||
])
|
||||
),
|
||||
batch_size=config.TRAIN.BATCH_SIZE,
|
||||
shuffle=True,
|
||||
num_workers=config.WORKERS,
|
||||
pin_memory=config.PIN_MEMORY
|
||||
)
|
||||
|
||||
val_loader = torch.utils.data.DataLoader(
|
||||
I3DDataSet(
|
||||
data_root=config.DATASET.DIR,
|
||||
split=config.DATASET.SPLIT,
|
||||
modality=config.TRAIN.MODALITY,
|
||||
train_mode=False,
|
||||
transform=torchvision.transforms.Compose([
|
||||
GroupScale(config.TRAIN.RESIZE_MIN),
|
||||
GroupCenterCrop(config.TRAIN.INPUT_SIZE),
|
||||
GroupNormalize(modality=config.TRAIN.MODALITY),
|
||||
Stack(),
|
||||
]),
|
||||
),
|
||||
batch_size=config.TEST.BATCH_SIZE,
|
||||
shuffle=False,
|
||||
num_workers=config.WORKERS,
|
||||
pin_memory=config.PIN_MEMORY
|
||||
)
|
||||
|
||||
# Setup model
|
||||
if config.TRAIN.MODALITY == "RGB":
|
||||
channels = 3
|
||||
checkpoint = config.MODEL.PRETRAINED_RGB
|
||||
elif config.TRAIN.MODALITY == "flow":
|
||||
channels = 2
|
||||
checkpoint = config.MODEL.PRETRAINED_FLOW
|
||||
else:
|
||||
raise ValueError("Modality must be RGB or flow")
|
||||
|
||||
i3d_model = InceptionI3d(400, in_channels=channels)
|
||||
i3d_model.load_state_dict(torch.load(checkpoint))
|
||||
|
||||
# Replace final FC layer to match dataset
|
||||
i3d_model.replace_logits(config.DATASET.NUM_CLASSES)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss().cuda()
|
||||
|
||||
optimizer = optim.SGD(
|
||||
i3d_model.parameters(),
|
||||
lr=0.1,
|
||||
momentum=0.9,
|
||||
weight_decay=0.0000001
|
||||
)
|
||||
|
||||
i3d_model = i3d_model.cuda()
|
||||
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer,
|
||||
factor=0.1,
|
||||
patience=2,
|
||||
verbose=True,
|
||||
threshold=1e-4,
|
||||
min_lr=1e-4
|
||||
)
|
||||
|
||||
# Data-parallel
|
||||
devices_lst = list(range(torch.cuda.device_count()))
|
||||
print("Devices {}".format(devices_lst))
|
||||
if len(devices_lst) > 1:
|
||||
i3d_model = torch.nn.DataParallel(i3d_model)
|
||||
|
||||
if not os.path.exists(config.MODEL.CHECKPOINT_DIR):
|
||||
os.makedirs(config.MODEL.CHECKPOINT_DIR)
|
||||
|
||||
for epoch in range(config.TRAIN.MAX_EPOCHS):
|
||||
|
||||
train(train_loader,
|
||||
i3d_model,
|
||||
criterion,
|
||||
optimizer,
|
||||
epoch,
|
||||
writer
|
||||
)
|
||||
|
||||
if (epoch + 1) % config.TEST.EVAL_FREQ == 0 or epoch == config.TRAIN.MAX_EPOCHS - 1:
|
||||
val_loss = validate(val_loader, i3d_model, criterion, epoch, writer)
|
||||
scheduler.step(val_loss)
|
||||
torch.save(
|
||||
i3d_model.module.state_dict(),
|
||||
config.MODEL.CHECKPOINT_DIR+'/'+config.MODEL.NAME+'_split'+str(config.DATASET.SPLIT)+'_epoch'+str(epoch).zfill(3)+'.pt'
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(run)
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Adapted from https://github.com/feiyunzhang/i3d-non-local-pytorch/blob/master/transforms.py
|
||||
|
||||
import torchvision
|
||||
import random
|
||||
from PIL import Image, ImageOps
|
||||
import numpy as np
|
||||
import numbers
|
||||
import math
|
||||
import torch
|
||||
|
||||
|
||||
class GroupScale(object):
|
||||
|
||||
def __init__(self, size, interpolation=Image.BILINEAR):
|
||||
self.worker = torchvision.transforms.Resize(size, interpolation)
|
||||
|
||||
def __call__(self, img_group):
|
||||
return [self.worker(img) for img in img_group]
|
||||
|
||||
|
||||
class GroupRandomCrop(object):
|
||||
def __init__(self, size):
|
||||
if isinstance(size, numbers.Number):
|
||||
self.size = (int(size), int(size))
|
||||
else:
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img_group):
|
||||
|
||||
w, h = img_group[0].size
|
||||
th, tw = self.size
|
||||
|
||||
out_images = list()
|
||||
|
||||
x1 = random.randint(0, w - tw)
|
||||
y1 = random.randint(0, h - th)
|
||||
|
||||
for img in img_group:
|
||||
assert(img.size[0] == w and img.size[1] == h)
|
||||
if w == tw and h == th:
|
||||
out_images.append(img)
|
||||
else:
|
||||
out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
|
||||
|
||||
return out_images
|
||||
|
||||
|
||||
class GroupCenterCrop(object):
|
||||
def __init__(self, size):
|
||||
self.worker = torchvision.transforms.CenterCrop(size)
|
||||
|
||||
def __call__(self, img_group):
|
||||
cropped_imgs = [self.worker(img) for img in img_group]
|
||||
return cropped_imgs
|
||||
|
||||
|
||||
class GroupRandomHorizontalFlip(object):
|
||||
|
||||
def __call__(self, img_group):
|
||||
v = random.random()
|
||||
if v < 0.5:
|
||||
ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
|
||||
return ret
|
||||
else:
|
||||
return img_group
|
||||
|
||||
|
||||
class GroupNormalize(object):
|
||||
|
||||
def __init__(self, modality, means=[0.485, 0.456, 0.406], stds=[0.229, 0.224, 0.225]):
|
||||
self.modality = modality
|
||||
self.means = means
|
||||
self.stds = stds
|
||||
self.tensor_worker = torchvision.transforms.ToTensor()
|
||||
self.norm_worker = torchvision.transforms.Normalize(mean=means, std=stds)
|
||||
|
||||
def __call__(self, img_group):
|
||||
if self.modality == "RGB":
|
||||
# Convert images to tensors in range [0, 1]
|
||||
img_tensors = [self.tensor_worker(img) for img in img_group]
|
||||
# Normalize to imagenet means and stds
|
||||
img_tensors = [self.norm_worker(img) for img in img_tensors]
|
||||
else:
|
||||
# Convert images to numpy arrays
|
||||
img_arrays = [np.asarray(img).transpose([2, 0, 1]) for img in img_group]
|
||||
# Scale to [-1, 1] and convert to tensor
|
||||
img_tensors = [torch.from_numpy((img / 255.) * 2 - 1) for img in img_arrays]
|
||||
return img_tensors
|
||||
|
||||
|
||||
class Stack(object):
|
||||
|
||||
def __call__(self, img_tensors):
|
||||
# Stack tensors and permute from D x C x H x W to C x D x H x W
|
||||
return torch.stack(img_tensors, dim=0).permute(1, 0, 2, 3).float()
|
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 8.5 MiB |
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 8.3 MiB |
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 221 KiB |
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 459 KiB |
|
@ -64,6 +64,7 @@ steps:
|
|||
displayName: 'submit_azureml_pytest'
|
||||
|
||||
- task: PublishTestResults@2
|
||||
timeoutInMinutes: 360
|
||||
displayName: 'Publish Test Results **/test-*.xml'
|
||||
inputs:
|
||||
testResultsFiles: '**/test-*.xml'
|
||||
|
|
|
@ -58,6 +58,7 @@ steps:
|
|||
displayName: 'submit_azureml_pytest'
|
||||
|
||||
- task: PublishTestResults@2
|
||||
timeoutInMinutes: 360
|
||||
displayName: 'Publish Test Results **/test-*.xml'
|
||||
condition: always()
|
||||
inputs:
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import papermill as pm
|
||||
import pytest
|
||||
import scrapbook as sb
|
||||
|
||||
# Parameters
|
||||
KERNEL_NAME = "python3"
|
||||
OUTPUT_NOTEBOOK = "output.ipynb"
|
||||
|
||||
|
||||
@pytest.mark.notebooks
|
||||
@pytest.mark.linuxgpu
|
||||
def test_01_notebook_run(action_recognition_notebooks):
|
||||
epochs = 4
|
||||
notebook_path = action_recognition_notebooks["01"]
|
||||
pm.execute_notebook(
|
||||
notebook_path,
|
||||
OUTPUT_NOTEBOOK,
|
||||
parameters=dict(PM_VERSION=pm.__version__, EPOCHS=epochs),
|
||||
kernel_name=KERNEL_NAME,
|
||||
)
|
||||
|
||||
nb_output = sb.read_notebook(OUTPUT_NOTEBOOK)
|
||||
|
||||
vid_pred_accuracy = nb_output.scraps["vid_pred_accuracy"].data
|
||||
clip_pred_accuracy = nb_output.scraps["clip_pred_accuracy"].data
|
||||
|
||||
assert vid_pred_accuracy > 0.3
|
||||
assert clip_pred_accuracy > 0.3
|
|
@ -24,8 +24,7 @@ def test_00_notebook_run(action_recognition_notebooks):
|
|||
notebook_path,
|
||||
OUTPUT_NOTEBOOK,
|
||||
parameters=dict(
|
||||
PM_VERSION=pm.__version__,
|
||||
sample_video_url=Urls.webcam_vid_low_res
|
||||
PM_VERSION=pm.__version__, sample_video_url=Urls.webcam_vid_low_res
|
||||
),
|
||||
kernel_name=KERNEL_NAME,
|
||||
)
|
||||
|
|
Загрузка…
Ссылка в новой задаче