Add r2+1d from staging into master
r2p1d model fine-tuning scripts and examples (#418)
This commit is contained in:
Коммит
184cc9722e
|
@ -116,6 +116,7 @@ output.ipynb
|
|||
# don't save any data
|
||||
classification/data/*
|
||||
data/
|
||||
!contrib/action_recognition/r2p1d/**
|
||||
|
||||
# don't save .swp files
|
||||
*.swp
|
||||
|
@ -123,9 +124,6 @@ data/
|
|||
# don't save .csv files
|
||||
*.csv
|
||||
|
||||
# don't save data dir
|
||||
data
|
||||
|
||||
# don't save pickles
|
||||
*.pkl
|
||||
|
||||
|
|
|
@ -1,31 +1,69 @@
|
|||
# Action Recognition
|
||||
|
||||
This is a place holder. Content will follow soon.
|
||||
Action recognition (also often called activity recognition) consists of classifying different actions from a sequence
|
||||
of frames in videos.
|
||||
|
||||
![](./media/action_recognition.gif)
|
||||
This directory contains example projects for building video-based action recognition systems.
|
||||
Our goal is to enable users to easily and quickly train high-accuracy action recognition models with fast inference speed.
|
||||
|
||||
![](./media/action_recognition.gif "Example of action recognition")
|
||||
|
||||
*Example of action recognition*
|
||||
|
||||
## Overview
|
||||
|
||||
| Folders | Description |
|
||||
Currently, we provide two state of the art model implementations, Two-Stream [Inflated 3D ConvNet, I3D](https://arxiv.org/pdf/1705.07750.pdf)
|
||||
and RGB [ResNets with (2+1)D convolutions, R(2+1)D](https://arxiv.org/abs/1711.11248)
|
||||
along with their example notebooks for fine-tuning on [HMDB-51 dataset](http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/).
|
||||
More details about the models can be found in [Models](#models) section below.
|
||||
|
||||
Each project describes independent SETUP process with a separate conda environment under its directory.
|
||||
|
||||
We recommend to use R(2+1)D implementation for its competitive accuracy (see below [comparison table](#comparison)) with much faster inference speed as well as less-dependencies on other packages.
|
||||
We also provide the Webcam stream notebook for R(2+1)D to test real-time inference.
|
||||
|
||||
Nevertheless, I3D implementation gives a good example of utilizing the two-stream approach for action recognition so that those who want to benchmark and/or tryout different approaches can use.
|
||||
|
||||
|
||||
## Projects
|
||||
|
||||
| Directory | Description |
|
||||
| -------- | ----------- |
|
||||
| [i3d](i3d) | Scripts for fine-tuning a pre-trained Two-Stream Inflated 3D ConvNet (I3D) model on the HMDB-51 dataset
|
||||
| [r2p1d](r2p1d) | Scripts for fine-tuning a pre-trained R(2+1)D model on HMDB-51 dataset
|
||||
| [i3d](i3d) | Scripts for fine-tuning a pre-trained I3D model on HMDB-51 dataset
|
||||
| [video_annotation](video_annotation) | Instructions and helper functions to annotate the start and end position of actions in video footage|
|
||||
|
||||
## Functionality
|
||||
|
||||
In [i3d](i3d) we show how to fine-tune a Two-Stream Inflated 3D ConvNet (I3D) model. This model was introduced in \[[1](https://arxiv.org/pdf/1705.07750.pdf)\] and achieved state-of-the-art in action classification on the HMDB-51 and UCF-101 datasets. The paper demonstrated the effectiveness of pre-training action recognition models on large datasets - in this case the Kinetics Human Action Video dataset consisting of 306k examples and 400 classes. We provide code for replicating the results of this paper on HMDB-51. We use models pre-trained on Kinetics from [https://github.com/piergiaj/pytorch-i3d](https://github.com/piergiaj/pytorch-i3d). Evaluating the model on the test set of the HMDB-51 dataset (split 1) using [i3d/test.py](i3d/test.py) should yield the following results:
|
||||
### Models
|
||||
|
||||
| Model | Paper top 1 accuracy (average over 3 splits) | Our models top 1 accuracy (split 1 only) |
|
||||
The R(2+1)D model was presented in \[2\] where the authors pretrained the model on [Kinetics400](https://arxiv.org/abs/1705.06950) and produced decent performance close to state of the art on the HMDB-51 dataset.
|
||||
In [r2p1d](r2p1d) we demonstrate fine-tuning R(2+1)D model [pretrained on 65 million videos](https://arxiv.org/abs/1905.00561).
|
||||
We use the pretrained weight from [https://github.com/moabitcoin/ig65m-pytorch](https://github.com/moabitcoin/ig65m-pytorch).
|
||||
|
||||
In [i3d](i3d) we show how to fine-tune I3D model. This model was introduced in \[[1](https://arxiv.org/pdf/1705.07750.pdf)\]
|
||||
and achieved state of the art in action classification on the HMDB-51 and UCF-101 datasets.
|
||||
Here, we use models pre-trained on Kinetics from [https://github.com/piergiaj/pytorch-i3d](https://github.com/piergiaj/pytorch-i3d).
|
||||
|
||||
The following table shows the comparison between the reported performance in the original papers and our results on HMDB-51 dataset.
|
||||
Please note that the accuracies from the papers are averages over 3 splits, while ours are based on the split 1 only.
|
||||
Also, the original R(2+1)D paper used the model pretrained on Kinetics400 but we used the one pretrained on the 65 million videos which explains the higher accuracy (74.5% vs 79.8%).
|
||||
|
||||
|
||||
*Comparison on HMDB-51*
|
||||
|
||||
<a id="comparison"></a>
|
||||
|
||||
| Model | Reported in the paper | Our results |
|
||||
| ------- | -------| ------- |
|
||||
| RGB | 74.8 | 73.7 |
|
||||
| Optical flow | 77.1 | 77.5 |
|
||||
| Two-Stream | 80.7 | 81.2 |
|
||||
| R(2+1)D RGB | 74.5 | 79.8 |
|
||||
| I3D RGB | 74.8 | 73.7 |
|
||||
| I3D Optical flow | 77.1 | 77.5 |
|
||||
| I3D Two-Stream | 80.7 | 81.2 |
|
||||
|
||||
|
||||
### Annotation
|
||||
In order to train an action recognition model for a specific task, annotated training data from the relevant domain is needed. In [video_annotation](video_annotation), we provide tips and examples for how to use a best-in-class video annotation tool ([VGG Image Annotator](http://www.robots.ox.ac.uk/~vgg/software/via/)) to label the start and end positions of actions in videos.
|
||||
|
||||
## State-of-the-art
|
||||
## State of the art
|
||||
|
||||
In the tables below, we list datasets which are commonly used and also give an overview of the state-of-the-art. Note that the information below is reasonably exhaustive and should cover most major publications until 2018. Expect however some level of incompleteness and slight incorrectness (e.g. publication year being off by plus/minus 1 year due) since the tables below were mainly compiled to give a high-level picture of where the field is and how it evolved over the last years.
|
||||
|
||||
|
@ -34,7 +72,7 @@ Recommended reading:
|
|||
- [ActionRecognition.net](http://actionrecognition.net/files/dset.php) for the latest state-of-the-art accuracies on popular research benchmark datasets.
|
||||
- All papers highlighted in yellow in the publications table below.
|
||||
|
||||
Popular datasets:
|
||||
*Popular datasets*
|
||||
|
||||
| Name | Year | Number of classes | #Clips | Average length per video | Notes |
|
||||
| ----- | ----- | ----------------- | ------- | ------------------------- | ----------- |
|
||||
|
@ -54,13 +92,16 @@ Popular datasets:
|
|||
|Youtube-8M Segments| 2019| 1000| 237k| 5sec| Used for localization Kaggle challenge. Think focuses on objects, not actions.|
|
||||
|
||||
|
||||
|
||||
Popular publications, with recommended papers to read highlighted in yellow:
|
||||
<img align="center" src="./media/publications.png"/>
|
||||
|
||||
<img align="center" src="./media/publications.png" />
|
||||
|
||||
|
||||
Most pulications focus on accuracy rather than on inferencing speed. The paper "Representation Flow for Action Recognition" is a noteworthy exception with this figure:
|
||||
<img align="center" src="./media/inference_speeds.png" width = "500"/>
|
||||
Most publications focus on accuracy rather than on inferencing speed. The paper "Representation Flow for Action Recognition" is a noteworthy exception with this figure:
|
||||
|
||||
<img align="center" src="./media/inference_speeds.png" width = "500" />
|
||||
|
||||
\[1\] J. Carreira and A. Zisserman. Quo vadis, action recognition?
|
||||
a new model and the kinetics dataset. In CVPR, 2017.
|
||||
a new model and the kinetics dataset. In CVPR, 2017.
|
||||
|
||||
\[2\] D. Tran, et al. A Closer Look at Spatiotemporal Convolutions for Action Recognition. arXiv:1711.11248 \[cs.CV\], 2017.
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
|
||||
# Data and models
|
||||
data/hmdb51/*
|
||||
!data/hmdb51/hmdb51_vid_*.txt
|
||||
data/kinetics400/*
|
||||
!data/kinetics400/label_map.txt
|
||||
*checkpoints/
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -0,0 +1,59 @@
|
|||
# Human Action Recognition with R(2+1)D Model
|
||||
|
||||
R(2+1)D model architecture was originally presented in "A Closer Look at Spatiotemporal Convolutions for Action Recognition (2017)" paper from Facebook AI group.
|
||||
|
||||
<img src="model_arch.jpg" width="300" height="300" />
|
||||
|
||||
This project provides utility scripts to fine-tune the model and examples notebooks as follows:
|
||||
|
||||
| Notebook | Description |
|
||||
| --- | --- |
|
||||
| [00_webcam](00_webcam.ipynb) | A real-time inference example on Webcam stream |
|
||||
| [01_training_introduction](01_training_introduction.ipynb) | An example of training R(2+1)D model on HMDB-51 dataset |
|
||||
| [02_video_transformation](02_video_transformation.ipynb) | Examples of video transformations |
|
||||
|
||||
Specifically, we use the model pre-trained on 65 million social media videos (IG) presented in "[Large-scale weakly-supervised pre-training for video action recognition (2019)](https://arxiv.org/abs/1905.00561)" paper.
|
||||
|
||||
*Note: The official pretrained model weights can be found from [https://github.com/facebookresearch/vmz](https://github.com/facebookresearch/vmz) which are based on caffe2.
|
||||
In this repository, we use PyTorch-converted weights from [https://github.com/moabitcoin/ig65m-pytorch](https://github.com/moabitcoin/ig65m-pytorch).*
|
||||
|
||||
|
||||
## Prerequisite
|
||||
* Linux machine - We strongly recommend to use GPU machine to run the scripts and notebooks in this project smoothly (preferably [Azure NCsV3 series VMs](https://docs.microsoft.com/en-us/azure/virtual-machines/linux/sizes-gpu#ncv3-series)).
|
||||
* To use GPUs, **CUDA toolkit v10.1** is required. Details about the CUDA installation can be found [here](https://developer.nvidia.com/cuda-downloads). Once the installation is completed, you may need to reboot the VM.
|
||||
|
||||
|
||||
## Installation
|
||||
1. Setup conda environment
|
||||
`conda env create -f environment.yml`
|
||||
|
||||
1. Activate the environment
|
||||
`conda activate r2p1d`
|
||||
|
||||
1. Install jupyter kernel
|
||||
`python -m ipykernel install --user --name r2p1d`
|
||||
|
||||
### (Optional) Mixed-precision training
|
||||
* To use mixed-precision training via [NVIDIA-apex](https://github.com/NVIDIA/apex),
|
||||
```
|
||||
$ git clone https://github.com/NVIDIA/apex
|
||||
$ cd apex
|
||||
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
|
||||
```
|
||||
|
||||
|
||||
### WebCam tunneling
|
||||
To run the model on remote GPU VM while using a local machine for WebCam streaming,
|
||||
|
||||
1. Open a terminal window that supports `ssh` from the local machine (e.g. [git-bash for Windows](https://gitforwindows.org/)).
|
||||
|
||||
1. Run following commandPort forward as follows (assuming Jupyter notebook will be running on the port 8888 on the VM)
|
||||
`ssh your-vm-address -L 8888:localhost:8888`
|
||||
|
||||
1. Clone this repository from the VM and install the conda environment and Jupyter kernel.
|
||||
|
||||
1. Start Jupyter notebook from the VM without starting a browser.
|
||||
`jupyter notebook --no-browser`
|
||||
You can also set Jupyter configuration to not start browser by default.
|
||||
|
||||
1. Copy the notebook address showing at the terminal and paste it to the browser on the local machine to open it.
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,400 @@
|
|||
abseiling
|
||||
air drumming
|
||||
answering questions
|
||||
applauding
|
||||
applying cream
|
||||
archery
|
||||
arm wrestling
|
||||
arranging flowers
|
||||
assembling computer
|
||||
auctioning
|
||||
baby waking up
|
||||
baking cookies
|
||||
balloon blowing
|
||||
bandaging
|
||||
barbequing
|
||||
bartending
|
||||
beatboxing
|
||||
bee keeping
|
||||
belly dancing
|
||||
bench pressing
|
||||
bending back
|
||||
bending metal
|
||||
biking through snow
|
||||
blasting sand
|
||||
blowing glass
|
||||
blowing leaves
|
||||
blowing nose
|
||||
blowing out candles
|
||||
bobsledding
|
||||
bookbinding
|
||||
bouncing on trampoline
|
||||
bowling
|
||||
braiding hair
|
||||
breading or breadcrumbing
|
||||
breakdancing
|
||||
brush painting
|
||||
brushing hair
|
||||
brushing teeth
|
||||
building cabinet
|
||||
building shed
|
||||
bungee jumping
|
||||
busking
|
||||
canoeing or kayaking
|
||||
capoeira
|
||||
carrying baby
|
||||
cartwheeling
|
||||
carving pumpkin
|
||||
catching fish
|
||||
catching or throwing baseball
|
||||
catching or throwing frisbee
|
||||
catching or throwing softball
|
||||
celebrating
|
||||
changing oil
|
||||
changing wheel
|
||||
checking tires
|
||||
cheerleading
|
||||
chopping wood
|
||||
clapping
|
||||
clay pottery making
|
||||
clean and jerk
|
||||
cleaning floor
|
||||
cleaning gutters
|
||||
cleaning pool
|
||||
cleaning shoes
|
||||
cleaning toilet
|
||||
cleaning windows
|
||||
climbing a rope
|
||||
climbing ladder
|
||||
climbing tree
|
||||
contact juggling
|
||||
cooking chicken
|
||||
cooking egg
|
||||
cooking on campfire
|
||||
cooking sausages
|
||||
counting money
|
||||
country line dancing
|
||||
cracking neck
|
||||
crawling baby
|
||||
crossing river
|
||||
crying
|
||||
curling hair
|
||||
cutting nails
|
||||
cutting pineapple
|
||||
cutting watermelon
|
||||
dancing ballet
|
||||
dancing charleston
|
||||
dancing gangnam style
|
||||
dancing macarena
|
||||
deadlifting
|
||||
decorating the christmas tree
|
||||
digging
|
||||
dining
|
||||
disc golfing
|
||||
diving cliff
|
||||
dodgeball
|
||||
doing aerobics
|
||||
doing laundry
|
||||
doing nails
|
||||
drawing
|
||||
dribbling basketball
|
||||
drinking
|
||||
drinking beer
|
||||
drinking shots
|
||||
driving car
|
||||
driving tractor
|
||||
drop kicking
|
||||
drumming fingers
|
||||
dunking basketball
|
||||
dying hair
|
||||
eating burger
|
||||
eating cake
|
||||
eating carrots
|
||||
eating chips
|
||||
eating doughnuts
|
||||
eating hotdog
|
||||
eating ice cream
|
||||
eating spaghetti
|
||||
eating watermelon
|
||||
egg hunting
|
||||
exercising arm
|
||||
exercising with an exercise ball
|
||||
extinguishing fire
|
||||
faceplanting
|
||||
feeding birds
|
||||
feeding fish
|
||||
feeding goats
|
||||
filling eyebrows
|
||||
finger snapping
|
||||
fixing hair
|
||||
flipping pancake
|
||||
flying kite
|
||||
folding clothes
|
||||
folding napkins
|
||||
folding paper
|
||||
front raises
|
||||
frying vegetables
|
||||
garbage collecting
|
||||
gargling
|
||||
getting a haircut
|
||||
getting a tattoo
|
||||
giving or receiving award
|
||||
golf chipping
|
||||
golf driving
|
||||
golf putting
|
||||
grinding meat
|
||||
grooming dog
|
||||
grooming horse
|
||||
gymnastics tumbling
|
||||
hammer throw
|
||||
headbanging
|
||||
headbutting
|
||||
high jump
|
||||
high kick
|
||||
hitting baseball
|
||||
hockey stop
|
||||
holding snake
|
||||
hopscotch
|
||||
hoverboarding
|
||||
hugging
|
||||
hula hooping
|
||||
hurdling
|
||||
hurling (sport)
|
||||
ice climbing
|
||||
ice fishing
|
||||
ice skating
|
||||
ironing
|
||||
javelin throw
|
||||
jetskiing
|
||||
jogging
|
||||
juggling balls
|
||||
juggling fire
|
||||
juggling soccer ball
|
||||
jumping into pool
|
||||
jumpstyle dancing
|
||||
kicking field goal
|
||||
kicking soccer ball
|
||||
kissing
|
||||
kitesurfing
|
||||
knitting
|
||||
krumping
|
||||
laughing
|
||||
laying bricks
|
||||
long jump
|
||||
lunge
|
||||
making a cake
|
||||
making a sandwich
|
||||
making bed
|
||||
making jewelry
|
||||
making pizza
|
||||
making snowman
|
||||
making sushi
|
||||
making tea
|
||||
marching
|
||||
massaging back
|
||||
massaging feet
|
||||
massaging legs
|
||||
massaging person's head
|
||||
milking cow
|
||||
mopping floor
|
||||
motorcycling
|
||||
moving furniture
|
||||
mowing lawn
|
||||
news anchoring
|
||||
opening bottle
|
||||
opening present
|
||||
paragliding
|
||||
parasailing
|
||||
parkour
|
||||
passing American football (in game)
|
||||
passing American football (not in game)
|
||||
peeling apples
|
||||
peeling potatoes
|
||||
petting animal (not cat)
|
||||
petting cat
|
||||
picking fruit
|
||||
planting trees
|
||||
plastering
|
||||
playing accordion
|
||||
playing badminton
|
||||
playing bagpipes
|
||||
playing basketball
|
||||
playing bass guitar
|
||||
playing cards
|
||||
playing cello
|
||||
playing chess
|
||||
playing clarinet
|
||||
playing controller
|
||||
playing cricket
|
||||
playing cymbals
|
||||
playing didgeridoo
|
||||
playing drums
|
||||
playing flute
|
||||
playing guitar
|
||||
playing harmonica
|
||||
playing harp
|
||||
playing ice hockey
|
||||
playing keyboard
|
||||
playing kickball
|
||||
playing monopoly
|
||||
playing organ
|
||||
playing paintball
|
||||
playing piano
|
||||
playing poker
|
||||
playing recorder
|
||||
playing saxophone
|
||||
playing squash or racquetball
|
||||
playing tennis
|
||||
playing trombone
|
||||
playing trumpet
|
||||
playing ukulele
|
||||
playing violin
|
||||
playing volleyball
|
||||
playing xylophone
|
||||
pole vault
|
||||
presenting weather forecast
|
||||
pull ups
|
||||
pumping fist
|
||||
pumping gas
|
||||
punching bag
|
||||
punching person (boxing)
|
||||
push up
|
||||
pushing car
|
||||
pushing cart
|
||||
pushing wheelchair
|
||||
reading book
|
||||
reading newspaper
|
||||
recording music
|
||||
riding a bike
|
||||
riding camel
|
||||
riding elephant
|
||||
riding mechanical bull
|
||||
riding mountain bike
|
||||
riding mule
|
||||
riding or walking with horse
|
||||
riding scooter
|
||||
riding unicycle
|
||||
ripping paper
|
||||
robot dancing
|
||||
rock climbing
|
||||
rock scissors paper
|
||||
roller skating
|
||||
running on treadmill
|
||||
sailing
|
||||
salsa dancing
|
||||
sanding floor
|
||||
scrambling eggs
|
||||
scuba diving
|
||||
setting table
|
||||
shaking hands
|
||||
shaking head
|
||||
sharpening knives
|
||||
sharpening pencil
|
||||
shaving head
|
||||
shaving legs
|
||||
shearing sheep
|
||||
shining shoes
|
||||
shooting basketball
|
||||
shooting goal (soccer)
|
||||
shot put
|
||||
shoveling snow
|
||||
shredding paper
|
||||
shuffling cards
|
||||
side kick
|
||||
sign language interpreting
|
||||
singing
|
||||
situp
|
||||
skateboarding
|
||||
ski jumping
|
||||
skiing (not slalom or crosscountry)
|
||||
skiing crosscountry
|
||||
skiing slalom
|
||||
skipping rope
|
||||
skydiving
|
||||
slacklining
|
||||
slapping
|
||||
sled dog racing
|
||||
smoking
|
||||
smoking hookah
|
||||
snatch weight lifting
|
||||
sneezing
|
||||
sniffing
|
||||
snorkeling
|
||||
snowboarding
|
||||
snowkiting
|
||||
snowmobiling
|
||||
somersaulting
|
||||
spinning poi
|
||||
spray painting
|
||||
spraying
|
||||
springboard diving
|
||||
squat
|
||||
sticking tongue out
|
||||
stomping grapes
|
||||
stretching arm
|
||||
stretching leg
|
||||
strumming guitar
|
||||
surfing crowd
|
||||
surfing water
|
||||
sweeping floor
|
||||
swimming backstroke
|
||||
swimming breast stroke
|
||||
swimming butterfly stroke
|
||||
swing dancing
|
||||
swinging legs
|
||||
swinging on something
|
||||
sword fighting
|
||||
tai chi
|
||||
taking a shower
|
||||
tango dancing
|
||||
tap dancing
|
||||
tapping guitar
|
||||
tapping pen
|
||||
tasting beer
|
||||
tasting food
|
||||
testifying
|
||||
texting
|
||||
throwing axe
|
||||
throwing ball
|
||||
throwing discus
|
||||
tickling
|
||||
tobogganing
|
||||
tossing coin
|
||||
tossing salad
|
||||
training dog
|
||||
trapezing
|
||||
trimming or shaving beard
|
||||
trimming trees
|
||||
triple jump
|
||||
tying bow tie
|
||||
tying knot (not on a tie)
|
||||
tying tie
|
||||
unboxing
|
||||
unloading truck
|
||||
using computer
|
||||
using remote controller (not gaming)
|
||||
using segway
|
||||
vault
|
||||
waiting in line
|
||||
walking the dog
|
||||
washing dishes
|
||||
washing feet
|
||||
washing hair
|
||||
washing hands
|
||||
water skiing
|
||||
water sliding
|
||||
watering plants
|
||||
waxing back
|
||||
waxing chest
|
||||
waxing eyebrows
|
||||
waxing legs
|
||||
weaving basket
|
||||
welding
|
||||
whistling
|
||||
windsurfing
|
||||
wrapping present
|
||||
wrestling
|
||||
writing
|
||||
yawning
|
||||
yoga
|
||||
zumba
|
Двоичный файл не отображается.
|
@ -0,0 +1 @@
|
|||
drinking 100
|
|
@ -0,0 +1,30 @@
|
|||
name: r2p1d
|
||||
channels:
|
||||
- defaults
|
||||
- conda-forge
|
||||
- pytorch
|
||||
dependencies:
|
||||
- cudatoolkit>=10.1
|
||||
- pandas
|
||||
- numpy==1.17.2
|
||||
- python==3.6.8
|
||||
- pytorch>=1.2.0
|
||||
- torchvision>=0.4.0
|
||||
- ipykernel>=4.6.1
|
||||
- jupyter>=1.0.0
|
||||
- jupyter_contrib_nbextensions
|
||||
- jupyter_nbextensions_configurator
|
||||
- pytest>=3.6.4
|
||||
- scikit-learn>=0.19.1
|
||||
- pip>=19.0.3
|
||||
- pip:
|
||||
- azureml-sdk[notebooks]
|
||||
- azureml-dataprep[pandas]
|
||||
- decord
|
||||
- ipywebrtc
|
||||
- matplotlib
|
||||
- einops==0.1.0
|
||||
- pillow==6.2.0
|
||||
- six==1.12.0
|
||||
- tqdm==4.36.1
|
||||
|
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 45 KiB |
|
@ -0,0 +1,16 @@
|
|||
from vu.utils.common import Config, system_info
|
||||
|
||||
|
||||
def test_config():
|
||||
cfg = Config({'a': 1}, b=2, c=3)
|
||||
assert cfg.a == 1 and cfg.b == 2 and cfg.c == 3
|
||||
|
||||
cfg = Config({'a': 1, 'b': 2})
|
||||
assert cfg.a == 1 and cfg.b == 2
|
||||
|
||||
cfg2 = Config(cfg)
|
||||
assert cfg2.a == cfg.a and cfg2.b == cfg.b
|
||||
|
||||
|
||||
def test_system_info():
|
||||
system_info()
|
|
@ -0,0 +1,302 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from urllib.request import urlretrieve
|
||||
import warnings
|
||||
|
||||
import decord
|
||||
from einops.layers.torch import Rearrange
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from numpy.random import randint
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from .utils import transforms_video as transforms
|
||||
from .utils.functional_video import denormalize
|
||||
|
||||
|
||||
DEFAULT_MEAN = (0.43216, 0.394666, 0.37645)
|
||||
DEFAULT_STD = (0.22803, 0.22145, 0.216989)
|
||||
|
||||
|
||||
class _DatasetSpec:
|
||||
def __init__(self, label_url, root, num_classes):
|
||||
self.label_url = label_url
|
||||
self.root = root
|
||||
self.num_classes = num_classes
|
||||
self._class_names = None
|
||||
|
||||
@property
|
||||
def class_names(self):
|
||||
if self._class_names is None:
|
||||
label_filepath = os.path.join(self.root, "label_map.txt")
|
||||
if not os.path.isfile(label_filepath):
|
||||
os.makedirs(self.root, exist_ok=True)
|
||||
urlretrieve(self.label_url, label_filepath)
|
||||
with open(label_filepath) as f:
|
||||
self._class_names = [l.strip() for l in f]
|
||||
assert len(self._class_names) == self.num_classes
|
||||
|
||||
return self._class_names
|
||||
|
||||
|
||||
KINETICS = _DatasetSpec(
|
||||
"https://github.com/microsoft/ComputerVision/files/3746975/kinetics400_lable_map.txt",
|
||||
os.path.join("data", "kinetics400"),
|
||||
400
|
||||
)
|
||||
|
||||
HMDB51 = _DatasetSpec(
|
||||
"https://github.com/microsoft/ComputerVision/files/3746963/hmdb51_label_map.txt",
|
||||
os.path.join("data", "hmdb51"),
|
||||
51
|
||||
)
|
||||
|
||||
|
||||
class VideoRecord(object):
|
||||
def __init__(self, row):
|
||||
self._data = row
|
||||
self._num_frames = -1
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return self._data[0]
|
||||
|
||||
@property
|
||||
def num_frames(self):
|
||||
if self._num_frames == -1:
|
||||
self._num_frames = int(len([x for x in Path(self._data[0]).glob('img_*')]) - 1)
|
||||
return self._num_frames
|
||||
|
||||
@property
|
||||
def label(self):
|
||||
return int(self._data[1])
|
||||
|
||||
|
||||
class VideoDataset(Dataset):
|
||||
"""
|
||||
Args:
|
||||
split_file (str): Annotation file containing video filenames and labels.
|
||||
video_dir (str): Videos directory.
|
||||
num_segments (int): Number of clips to sample from each video.
|
||||
sample_length (int): Number of consecutive frames to sample from a video (i.e. clip length).
|
||||
sample_step (int): Sampling step.
|
||||
input_size (int or tuple): Model input image size.
|
||||
im_scale (int or tuple): Resize target size.
|
||||
resize_keep_ratio (bool): If True, keep the original ratio when resizing.
|
||||
mean (tuple): Normalization mean.
|
||||
std (tuple): Normalization std.
|
||||
random_shift (bool): Random temporal shift when sample a clip.
|
||||
temporal_jitter (bool): Randomly skip frames when sampling each frames.
|
||||
flip_ratio (float): Horizontal flip ratio.
|
||||
random_crop (bool): If False, do center-crop.
|
||||
random_crop_scales (tuple): Range of size of the origin size random cropped.
|
||||
video_ext (str): Video file extension.
|
||||
warning (bool): On or off warning.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
split_file,
|
||||
video_dir,
|
||||
num_segments=1,
|
||||
sample_length=8,
|
||||
sample_step=1,
|
||||
input_size=112,
|
||||
im_scale=128,
|
||||
resize_keep_ratio=True,
|
||||
mean=DEFAULT_MEAN,
|
||||
std=DEFAULT_STD,
|
||||
random_shift=False,
|
||||
temporal_jitter=False,
|
||||
flip_ratio=0.5,
|
||||
random_crop=False,
|
||||
random_crop_scales=(0.6, 1.0),
|
||||
video_ext="mp4",
|
||||
warning=False,
|
||||
):
|
||||
# TODO maybe check wrong arguments to early failure
|
||||
assert sample_step > 0
|
||||
assert num_segments > 0
|
||||
|
||||
self.video_dir = video_dir
|
||||
self.video_records = [
|
||||
VideoRecord(x.strip().split(" ")) for x in open(split_file)
|
||||
]
|
||||
|
||||
self.num_segments = num_segments
|
||||
self.sample_length = sample_length
|
||||
self.sample_step = sample_step
|
||||
self.presample_length = sample_length * sample_step
|
||||
|
||||
# Temporal noise
|
||||
self.random_shift = random_shift
|
||||
self.temporal_jitter = temporal_jitter
|
||||
|
||||
# Video transforms
|
||||
# 1. resize
|
||||
trfms = [
|
||||
transforms.ToTensorVideo(),
|
||||
transforms.ResizeVideo(im_scale, resize_keep_ratio),
|
||||
]
|
||||
# 2. crop
|
||||
if random_crop:
|
||||
if random_crop_scales is not None:
|
||||
crop = transforms.RandomResizedCropVideo(input_size, random_crop_scales)
|
||||
else:
|
||||
crop = transforms.RandomCropVideo(input_size)
|
||||
else:
|
||||
crop = transforms.CenterCropVideo(input_size)
|
||||
trfms.append(crop)
|
||||
# 3. flip
|
||||
trfms.append(transforms.RandomHorizontalFlipVideo(flip_ratio))
|
||||
# 4. normalize
|
||||
trfms.append(transforms.NormalizeVideo(mean, std))
|
||||
self.transforms = Compose(trfms)
|
||||
self.video_ext = video_ext
|
||||
self.warning = warning
|
||||
|
||||
def __len__(self):
|
||||
return len(self.video_records)
|
||||
|
||||
def _sample_indices(self, record):
|
||||
"""
|
||||
Args:
|
||||
record (VideoRecord): A video record.
|
||||
Return:
|
||||
list: Segment offsets (start indices)
|
||||
"""
|
||||
if record.num_frames > self.presample_length:
|
||||
if self.random_shift:
|
||||
# Random sample
|
||||
offsets = np.sort(
|
||||
randint(
|
||||
record.num_frames - self.presample_length + 1,
|
||||
size=self.num_segments,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Uniform sample
|
||||
distance = (record.num_frames - self.presample_length + 1) / self.num_segments
|
||||
offsets = np.array(
|
||||
[int(distance / 2.0 + distance * x) for x in range(self.num_segments)]
|
||||
)
|
||||
else:
|
||||
if self.warning:
|
||||
warnings.warn(
|
||||
"num_segments and/or sample_length > num_frames in {}".format(
|
||||
record.path
|
||||
)
|
||||
)
|
||||
offsets = np.zeros((self.num_segments,), dtype=int)
|
||||
|
||||
return offsets
|
||||
|
||||
def _get_frames(self, video_reader, offset):
|
||||
clip = list()
|
||||
|
||||
# decord.seek() seems to have a bug. use seek_accurate().
|
||||
video_reader.seek_accurate(offset)
|
||||
# first frame
|
||||
clip.append(video_reader.next().asnumpy())
|
||||
# remaining frames
|
||||
try:
|
||||
if self.temporal_jitter:
|
||||
for i in range(self.sample_length - 1):
|
||||
step = randint(self.sample_step + 1)
|
||||
if step == 0:
|
||||
clip.append(clip[-1].copy())
|
||||
else:
|
||||
if step > 1:
|
||||
video_reader.skip_frames(step - 1)
|
||||
cur_frame = video_reader.next().asnumpy()
|
||||
if len(cur_frame.shape) != 3:
|
||||
# maybe end of the video
|
||||
break
|
||||
clip.append(cur_frame)
|
||||
else:
|
||||
for i in range(self.sample_length - 1):
|
||||
if self.sample_step > 1:
|
||||
video_reader.skip_frames(self.sample_step - 1)
|
||||
cur_frame = video_reader.next().asnumpy()
|
||||
if len(cur_frame.shape) != 3:
|
||||
# maybe end of the video
|
||||
break
|
||||
clip.append(cur_frame)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
# if clip needs more frames, simply duplicate the last frame in the clip.
|
||||
while len(clip) < self.sample_length:
|
||||
clip.append(clip[-1].copy())
|
||||
|
||||
return clip
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""
|
||||
Return:
|
||||
clips (torch.tensor), label (int)
|
||||
"""
|
||||
record = self.video_records[idx]
|
||||
video_reader = decord.VideoReader(
|
||||
"{}.{}".format(os.path.join(self.video_dir, record.path), self.video_ext),
|
||||
# TODO try to add `ctx=decord.ndarray.gpu(0) or .cuda(0)`
|
||||
)
|
||||
record._num_frames = len(video_reader)
|
||||
|
||||
offsets = self._sample_indices(record)
|
||||
clips = np.array([self._get_frames(video_reader, o) for o in offsets])
|
||||
|
||||
if self.num_segments == 1:
|
||||
# [T, H, W, C] -> [C, T, H, W]
|
||||
return self.transforms(torch.from_numpy(clips[0])), record.label
|
||||
else:
|
||||
# [S, T, H, W, C] -> [S, C, T, H, W]
|
||||
return (
|
||||
torch.stack([
|
||||
self.transforms(torch.from_numpy(c)) for c in clips
|
||||
]),
|
||||
record.label
|
||||
)
|
||||
|
||||
|
||||
def show_batch(batch, sample_length, mean=DEFAULT_MEAN, std=DEFAULT_STD):
|
||||
"""
|
||||
Args:
|
||||
batch (list[torch.tensor]): List of sample (clip) tensors
|
||||
sample_length (int): Number of frames to show for each sample
|
||||
mean (tuple): Normalization mean
|
||||
std (tuple): Normalization std-dev
|
||||
"""
|
||||
batch_size = len(batch)
|
||||
plt.tight_layout()
|
||||
fig, axs = plt.subplots(
|
||||
batch_size,
|
||||
sample_length,
|
||||
figsize=(4 * sample_length, 3 * batch_size)
|
||||
)
|
||||
|
||||
for i, ax in enumerate(axs):
|
||||
if batch_size == 1:
|
||||
clip = batch[0]
|
||||
else:
|
||||
clip = batch[i]
|
||||
clip = Rearrange("c t h w -> t c h w")(clip)
|
||||
if not isinstance(ax, np.ndarray):
|
||||
ax = [ax]
|
||||
for j, a in enumerate(ax):
|
||||
a.axis("off")
|
||||
a.imshow(
|
||||
np.moveaxis(
|
||||
denormalize(
|
||||
clip[j],
|
||||
mean,
|
||||
std,
|
||||
).numpy(),
|
||||
0,
|
||||
-1,
|
||||
)
|
||||
)
|
|
@ -0,0 +1,401 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
AMP_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
AMP_AVAILABLE = False
|
||||
import torch
|
||||
import torch.cuda as cuda
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from vu.utils import Config
|
||||
from vu.data import (
|
||||
DEFAULT_MEAN,
|
||||
DEFAULT_STD,
|
||||
show_batch as _show_batch,
|
||||
VideoDataset,
|
||||
)
|
||||
|
||||
from vu.utils.metrics import accuracy, AverageMeter
|
||||
|
||||
# From https://github.com/moabitcoin/ig65m-pytorch
|
||||
TORCH_R2PLUS1D = "moabitcoin/ig65m-pytorch"
|
||||
MODELS = {
|
||||
# model: output classes
|
||||
'r2plus1d_34_32_ig65m': 359,
|
||||
'r2plus1d_34_32_kinetics': 400,
|
||||
'r2plus1d_34_8_ig65m': 487,
|
||||
'r2plus1d_34_8_kinetics': 400,
|
||||
}
|
||||
|
||||
|
||||
class R2Plus1D(object):
|
||||
def __init__(self, cfgs):
|
||||
self.configs = Config(cfgs)
|
||||
self.train_ds, self.valid_ds = self.load_datasets(self.configs)
|
||||
self.model = self.init_model(
|
||||
self.configs.sample_length,
|
||||
self.configs.base_model,
|
||||
self.configs.num_classes
|
||||
)
|
||||
self.model_name = "r2plus1d_34_{}_{}".format(self.configs.sample_length, self.configs.base_model)
|
||||
|
||||
@staticmethod
|
||||
def init_model(sample_length, base_model, num_classes=None):
|
||||
if sample_length not in (8, 32):
|
||||
raise ValueError(
|
||||
"Not supported input frame length {}. Should be 8 or 32"
|
||||
.format(sample_length)
|
||||
)
|
||||
if base_model not in ('ig65m', 'kinetics'):
|
||||
raise ValueError(
|
||||
"Not supported model {}. Should be 'ig65m' or 'kinetics'"
|
||||
.format(base_model)
|
||||
)
|
||||
|
||||
model_name = "r2plus1d_34_{}_{}".format(sample_length, base_model)
|
||||
|
||||
print("Loading {} model".format(model_name))
|
||||
|
||||
model = torch.hub.load(
|
||||
TORCH_R2PLUS1D, model_name, num_classes=MODELS[model_name], pretrained=True
|
||||
)
|
||||
|
||||
# Replace head
|
||||
if num_classes is not None:
|
||||
model.fc = nn.Linear(model.fc.in_features, num_classes)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def load_datasets(cfgs):
|
||||
"""Load VideoDataset
|
||||
|
||||
Args:
|
||||
cfgs (dict or Config): Dataset configuration. For validation dataset,
|
||||
data augmentation such as random shift and temporal jitter is not used.
|
||||
|
||||
Return:
|
||||
VideoDataset, VideoDataset: Train and validation datasets.
|
||||
If split file is not provided, returns None.
|
||||
"""
|
||||
cfgs = Config(cfgs)
|
||||
|
||||
train_split = cfgs.get('train_split', None)
|
||||
train_ds = None if train_split is None else VideoDataset(
|
||||
split_file=train_split,
|
||||
video_dir=cfgs.video_dir,
|
||||
num_segments=1,
|
||||
sample_length=cfgs.sample_length,
|
||||
sample_step=cfgs.get('temporal_jitter_step', cfgs.get('sample_step', 1)),
|
||||
input_size=112,
|
||||
im_scale=cfgs.get('im_scale', 128),
|
||||
resize_keep_ratio=cfgs.get('resize_keep_ratio', True),
|
||||
mean=cfgs.get('mean', DEFAULT_MEAN),
|
||||
std=cfgs.get('std', DEFAULT_STD),
|
||||
random_shift=cfgs.get('random_shift', True),
|
||||
temporal_jitter=True if cfgs.get('temporal_jitter_step', 0) > 0 else False,
|
||||
flip_ratio=cfgs.get('flip_ratio', 0.5),
|
||||
random_crop=cfgs.get('random_crop', True),
|
||||
random_crop_scales=cfgs.get('random_crop_scales', (0.6, 1.0)),
|
||||
video_ext=cfgs.video_ext,
|
||||
)
|
||||
|
||||
valid_split = cfgs.get('valid_split', None)
|
||||
valid_ds = None if valid_split is None else VideoDataset(
|
||||
split_file=valid_split,
|
||||
video_dir=cfgs.video_dir,
|
||||
num_segments=1,
|
||||
sample_length=cfgs.sample_length,
|
||||
sample_step=cfgs.get('sample_step', 1),
|
||||
input_size=112,
|
||||
im_scale=cfgs.get('im_scale', 128),
|
||||
resize_keep_ratio=True,
|
||||
mean=cfgs.get('mean', DEFAULT_MEAN),
|
||||
std=cfgs.get('std', DEFAULT_STD),
|
||||
random_shift=False,
|
||||
temporal_jitter=False,
|
||||
flip_ratio=0.0,
|
||||
random_crop=False, # == Center crop
|
||||
random_crop_scales=None,
|
||||
video_ext=cfgs.video_ext,
|
||||
)
|
||||
|
||||
return train_ds, valid_ds
|
||||
|
||||
def show_batch(self, which_data='train', num_samples=1):
|
||||
"""Plot first few samples in the datasets"""
|
||||
if which_data == 'train':
|
||||
batch = [self.train_ds[i][0] for i in range(num_samples)]
|
||||
elif which_data == 'valid':
|
||||
batch = [self.valid_ds[i][0] for i in range(num_samples)]
|
||||
else:
|
||||
raise ValueError("Unknown data type {}".format(which_data))
|
||||
_show_batch(
|
||||
batch,
|
||||
self.configs.sample_length,
|
||||
mean=self.configs.get('mean', DEFAULT_MEAN),
|
||||
std=self.configs.get('std', DEFAULT_STD),
|
||||
)
|
||||
|
||||
def freeze(self):
|
||||
"""Freeze model except the last layer"""
|
||||
self._set_requires_grad(False)
|
||||
for param in self.model.fc.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
def unfreeze(self):
|
||||
self._set_requires_grad(True)
|
||||
|
||||
def _set_requires_grad(self, requires_grad=True):
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
||||
def fit(self, train_cfgs):
|
||||
train_cfgs = Config(train_cfgs)
|
||||
|
||||
model_dir = train_cfgs.get('model_dir', "checkpoints")
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
if cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
num_devices = cuda.device_count()
|
||||
# Look for the optimal set of algorithms to use in cudnn. Use this only with fixed-size inputs.
|
||||
torch.backends.cudnn.benchmark = True
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
num_devices = 1
|
||||
|
||||
data_loaders = {}
|
||||
if self.train_ds is not None:
|
||||
data_loaders['train'] = DataLoader(
|
||||
self.train_ds,
|
||||
batch_size=train_cfgs.get('batch_size', 8) * num_devices,
|
||||
shuffle=True,
|
||||
num_workers=0, # Torch 1.2 has a bug when num-workers > 0 (0 means run a main-processor worker)
|
||||
pin_memory=True,
|
||||
)
|
||||
if self.valid_ds is not None:
|
||||
data_loaders['valid'] = DataLoader(
|
||||
self.valid_ds,
|
||||
batch_size=train_cfgs.get('batch_size', 8) * num_devices,
|
||||
shuffle=False,
|
||||
num_workers=0,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
# Move model to gpu before constructing optimizers and amp.initialize
|
||||
self.model.to(device)
|
||||
|
||||
named_params_to_update = {}
|
||||
total_params = 0
|
||||
for name, param in self.model.named_parameters():
|
||||
total_params += 1
|
||||
if param.requires_grad:
|
||||
named_params_to_update[name] = param
|
||||
|
||||
print("Params to learn:")
|
||||
if len(named_params_to_update) == total_params:
|
||||
print("\tfull network")
|
||||
else:
|
||||
for name in named_params_to_update:
|
||||
print("\t{}".format(name))
|
||||
|
||||
optimizer = optim.SGD(
|
||||
list(named_params_to_update.values()),
|
||||
lr=train_cfgs.lr,
|
||||
momentum=train_cfgs.momentum,
|
||||
weight_decay=train_cfgs.weight_decay,
|
||||
)
|
||||
|
||||
# Use mixed-precision if available
|
||||
# Currently, only O1 works with DataParallel: See issues https://github.com/NVIDIA/apex/issues/227
|
||||
if train_cfgs.get('mixed_prec', False) and AMP_AVAILABLE:
|
||||
# 'O0': Full FP32, 'O1': Conservative, 'O2': Standard, 'O3': Full FP16
|
||||
self.model, optimizer = amp.initialize(
|
||||
self.model,
|
||||
optimizer,
|
||||
opt_level="O1",
|
||||
loss_scale="dynamic",
|
||||
# keep_batchnorm_fp32=True doesn't work on 'O1'
|
||||
)
|
||||
|
||||
# Learning rate scheduler
|
||||
scheduler = None
|
||||
warmup_pct = train_cfgs.get('warmup_pct', None)
|
||||
lr_decay_steps = train_cfgs.get('lr_decay_steps', None)
|
||||
if warmup_pct is not None:
|
||||
# Use warmup with the one-cycle policy
|
||||
lr_decay_total_steps = train_cfgs.epochs if lr_decay_steps is None else lr_decay_steps
|
||||
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=train_cfgs.lr,
|
||||
total_steps=lr_decay_total_steps,
|
||||
pct_start=train_cfgs.get('warmup_pct', 0.3),
|
||||
base_momentum=0.9*train_cfgs.momentum,
|
||||
max_momentum=train_cfgs.momentum,
|
||||
final_div_factor=1/train_cfgs.get('lr_decay_factor', 0.0001),
|
||||
)
|
||||
elif lr_decay_steps is not None:
|
||||
lr_decay_total_steps = train_cfgs.epochs
|
||||
# Simple step-decay
|
||||
scheduler = torch.optim.lr_scheduler.StepLR(
|
||||
optimizer,
|
||||
step_size=lr_decay_steps,
|
||||
gamma=train_cfgs.get('lr_decay_factor', 0.1),
|
||||
)
|
||||
|
||||
# DataParallel after amp.initialize
|
||||
if num_devices > 1:
|
||||
model = nn.DataParallel(self.model)
|
||||
else:
|
||||
model = self.model
|
||||
|
||||
criterion = nn.CrossEntropyLoss().to(device)
|
||||
|
||||
for e in range(1, train_cfgs.epochs + 1):
|
||||
print("Epoch {} ==========".format(e))
|
||||
if scheduler is not None:
|
||||
print("lr={}".format(scheduler.get_lr()))
|
||||
|
||||
self.train_an_epoch(
|
||||
model,
|
||||
data_loaders,
|
||||
device,
|
||||
criterion,
|
||||
optimizer,
|
||||
grad_steps=train_cfgs.grad_steps,
|
||||
mixed_prec=train_cfgs.mixed_prec,
|
||||
)
|
||||
if scheduler is not None and e < lr_decay_total_steps:
|
||||
scheduler.step()
|
||||
|
||||
self.save(
|
||||
os.path.join(
|
||||
model_dir,
|
||||
"{model_name}_{epoch}.pt".format(
|
||||
model_name=train_cfgs.get('model_name', self.model_name),
|
||||
epoch=str(e).zfill(3)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def train_an_epoch(
|
||||
model,
|
||||
data_loaders,
|
||||
device,
|
||||
criterion,
|
||||
optimizer,
|
||||
grad_steps=1,
|
||||
mixed_prec=False,
|
||||
):
|
||||
"""Train / validate a model for one epoch.
|
||||
|
||||
:param model:
|
||||
:param data_loaders: dict {'train': train_dl, 'valid': valid_dl}
|
||||
:param device:
|
||||
:param criterion:
|
||||
:param optimizer:
|
||||
:param grad_steps: If > 1, use gradient accumulation. Useful for larger batching
|
||||
:param mixed_prec: If True, use FP16 + FP32 mixed precision via NVIDIA apex.amp
|
||||
:return: dict {
|
||||
'train/time': batch_time.avg,
|
||||
'train/loss': losses.avg,
|
||||
'train/top1': top1.avg,
|
||||
'train/top5': top5.avg,
|
||||
'valid/time': ...
|
||||
}
|
||||
"""
|
||||
assert "train" in data_loaders
|
||||
if mixed_prec and not AMP_AVAILABLE:
|
||||
warnings.warn(
|
||||
"NVIDIA apex module is not installed. Cannot use mixed-precision."
|
||||
)
|
||||
|
||||
result = OrderedDict()
|
||||
for phase in ["train", "valid"]:
|
||||
# switch mode
|
||||
if phase == "train":
|
||||
model.train()
|
||||
else:
|
||||
model.eval()
|
||||
|
||||
dl = data_loaders[phase]
|
||||
|
||||
batch_time = AverageMeter()
|
||||
losses = AverageMeter()
|
||||
top1 = AverageMeter()
|
||||
top5 = AverageMeter()
|
||||
|
||||
end = time.time()
|
||||
for step, (inputs, target) in enumerate(dl, start=1):
|
||||
inputs = inputs.to(device, non_blocking=True)
|
||||
target = target.to(device, non_blocking=True)
|
||||
|
||||
with torch.set_grad_enabled(phase == "train"):
|
||||
# compute output
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, target)
|
||||
|
||||
# measure accuracy and record loss
|
||||
prec1, prec5 = accuracy(outputs, target, topk=(1, 5))
|
||||
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1[0], inputs.size(0))
|
||||
top5.update(prec5[0], inputs.size(0))
|
||||
|
||||
if phase == "train":
|
||||
# make the accumulated gradient to be the same scale as without the accumulation
|
||||
loss = loss / grad_steps
|
||||
|
||||
if mixed_prec and AMP_AVAILABLE:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
if step % grad_steps == 0:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
print(
|
||||
"{} took {:.2f} sec: loss = {:.4f}, top1_acc = {:.4f}, top5_acc = {:.4f}".format(
|
||||
phase, batch_time.sum, losses.avg, top1.avg, top5.avg
|
||||
)
|
||||
)
|
||||
result["{}/time".format(phase)] = batch_time.sum
|
||||
result["{}/loss".format(phase)] = losses.avg
|
||||
result["{}/top1".format(phase)] = top1.avg
|
||||
result["{}/top5".format(phase)] = top5.avg
|
||||
|
||||
return result
|
||||
|
||||
def save(self, model_path):
|
||||
torch.save(
|
||||
self.model.state_dict(),
|
||||
model_path
|
||||
)
|
||||
|
||||
def load(self, model_name, model_dir="checkpoints"):
|
||||
"""
|
||||
TODO accept epoch. If None, load the latest model.
|
||||
:param model_name: Model name format should be 'name_0EE' where E is the epoch
|
||||
:param model_dir: By default, 'checkpoints'
|
||||
:return:
|
||||
"""
|
||||
self.model.load_state_dict(torch.load(
|
||||
os.path.join(model_dir, "{}.pt".format(model_name))
|
||||
))
|
|
@ -0,0 +1 @@
|
|||
from .common import Config, system_info
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) Microsoft
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.cuda as cuda
|
||||
import torchvision
|
||||
|
||||
|
||||
class Config(object):
|
||||
def __init__(self, config=None, **extras):
|
||||
"""Dictionary wrapper to access keys as attributes.
|
||||
|
||||
Args:
|
||||
config (dict or Config): Configurations
|
||||
extras (kwargs): Extra configurations
|
||||
|
||||
Examples:
|
||||
>>> cfg = Config({'lr': 0.01}, momentum=0.95)
|
||||
or
|
||||
>>> cfg = Config({'lr': 0.01, 'momentum': 0.95})
|
||||
then, use as follows:
|
||||
>>> print(cfg.lr, cfg.momentum)
|
||||
"""
|
||||
if config is not None:
|
||||
if isinstance(config, dict):
|
||||
for k in config:
|
||||
setattr(self, k, config[k])
|
||||
elif isinstance(config, self.__class__):
|
||||
self.__dict__ = config.__dict__.copy()
|
||||
else:
|
||||
raise ValueError("Unknown config")
|
||||
|
||||
for k, v in extras.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
def get(self, key, default):
|
||||
return getattr(self, key, default)
|
||||
|
||||
|
||||
def system_info():
|
||||
print(sys.version, "\n")
|
||||
print("PyTorch {}".format(torch.__version__), "\n")
|
||||
print("Torch-vision {}".format(torchvision.__version__), "\n")
|
||||
print("Available devices:")
|
||||
if cuda.is_available():
|
||||
for i in range(cuda.device_count()):
|
||||
print("{}: {}".format(i, cuda.get_device_name(i)))
|
||||
else:
|
||||
print("CPUs")
|
|
@ -0,0 +1,118 @@
|
|||
# Referred torchvision
|
||||
# https://github.com/pytorch/vision/blob/master/torchvision/transforms/_functional_video.py
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _is_tensor_video_clip(clip):
|
||||
if not torch.is_tensor(clip):
|
||||
raise TypeError("clip should be Tesnor. Got %s" % type(clip))
|
||||
|
||||
if not clip.ndimension() == 4:
|
||||
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def crop(clip, i, j, h, w):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
|
||||
"""
|
||||
assert len(clip.size()) == 4, "clip should be a 4D tensor"
|
||||
return clip[..., i:i + h, j:j + w]
|
||||
|
||||
|
||||
def resize(clip, target_size, interpolation_mode):
|
||||
assert len(target_size) == 2, "target size should be tuple (height, width)"
|
||||
return torch.nn.functional.interpolate(
|
||||
clip, size=target_size, mode=interpolation_mode
|
||||
)
|
||||
|
||||
|
||||
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
|
||||
"""
|
||||
Do spatial cropping and resizing to the video clip
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
|
||||
i (int): i in (i,j) i.e coordinates of the upper left corner.
|
||||
j (int): j in (i,j) i.e coordinates of the upper left corner.
|
||||
h (int): Height of the cropped region.
|
||||
w (int): Width of the cropped region.
|
||||
size (tuple(int, int)): height and width of resized clip
|
||||
Returns:
|
||||
clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
|
||||
"""
|
||||
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
|
||||
clip = crop(clip, i, j, h, w)
|
||||
clip = resize(clip, size, interpolation_mode)
|
||||
return clip
|
||||
|
||||
|
||||
def center_crop(clip, crop_size):
|
||||
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
|
||||
h, w = clip.size(-2), clip.size(-1)
|
||||
th, tw = crop_size
|
||||
assert h >= th and w >= tw, "height and width must be no smaller than crop_size"
|
||||
|
||||
i = int(round((h - th) / 2.0))
|
||||
j = int(round((w - tw) / 2.0))
|
||||
return crop(clip, i, j, th, tw)
|
||||
|
||||
|
||||
def to_tensor(clip):
|
||||
"""
|
||||
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
||||
permute the dimenions of clip tensor
|
||||
Args:
|
||||
clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
|
||||
Return:
|
||||
clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
|
||||
"""
|
||||
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
|
||||
if not clip.dtype == torch.uint8:
|
||||
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
|
||||
return clip.float().permute(3, 0, 1, 2) / 255.0
|
||||
|
||||
|
||||
def normalize(clip, mean, std, inplace=False):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
|
||||
mean (tuple): pixel RGB mean. Size is (3)
|
||||
std (tuple): pixel standard deviation. Size is (3)
|
||||
Returns:
|
||||
normalized clip (torch.tensor): Size is (C, T, H, W)
|
||||
"""
|
||||
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
|
||||
if not inplace:
|
||||
clip = clip.clone()
|
||||
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
|
||||
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
|
||||
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
|
||||
return clip
|
||||
|
||||
|
||||
def hflip(clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
|
||||
Returns:
|
||||
flipped clip (torch.tensor): Size is (C, T, H, W)
|
||||
"""
|
||||
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
|
||||
return clip.flip((-1))
|
||||
|
||||
|
||||
def denormalize(clip, mean, std):
|
||||
"""Denormalize a sample who was normalized by (x - mean) / std
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be de-normalized
|
||||
mean (tuple): pixel RGB mean. Size is (3)
|
||||
std (tuple): pixel standard deviation. Size is (3)
|
||||
Returns:
|
||||
"""
|
||||
result = clip.clone()
|
||||
for t, m, s in zip(result, mean, std):
|
||||
t.mul_(s).add_(m)
|
||||
return result
|
|
@ -0,0 +1,38 @@
|
|||
# From https://github.com/feiyunzhang/i3d-non-local-pytorch/blob/master/main.py
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
||||
with torch.no_grad():
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
|
@ -0,0 +1,263 @@
|
|||
# Referred torchvision:
|
||||
# https://github.com/pytorch/vision/blob/master/torchvision/transforms/_transforms_video.py
|
||||
|
||||
import math
|
||||
import numbers
|
||||
import random
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from . import functional_video as F
|
||||
|
||||
__all__ = [
|
||||
"ResizeVideo",
|
||||
"RandomCropVideo",
|
||||
"RandomResizedCropVideo",
|
||||
"CenterCropVideo",
|
||||
"NormalizeVideo",
|
||||
"ToTensorVideo",
|
||||
"RandomHorizontalFlipVideo",
|
||||
]
|
||||
|
||||
|
||||
class ResizeVideo(object):
|
||||
def __init__(self, size, keep_ratio=True, interpolation_mode="bilinear"):
|
||||
if isinstance(size, tuple):
|
||||
assert len(size) == 2, "size should be tuple (height, width)"
|
||||
self.size = size
|
||||
self.keep_ratio = keep_ratio
|
||||
self.interpolation_mode = interpolation_mode
|
||||
|
||||
def __call__(self, clip):
|
||||
size, scale = None, None
|
||||
if isinstance(self.size, numbers.Number):
|
||||
if self.keep_ratio:
|
||||
scale = self.size / min(clip.shape[-2:])
|
||||
else:
|
||||
size = (int(self.size), int(self.size))
|
||||
else:
|
||||
if self.keep_ratio:
|
||||
scale = min(self.size[0] / clip.shape[-2], self.size[1] / clip.shape[-1], )
|
||||
else:
|
||||
size = self.size
|
||||
|
||||
return nn.functional.interpolate(
|
||||
clip, size=size, scale_factor=scale,
|
||||
mode=self.interpolation_mode, align_corners=False
|
||||
)
|
||||
|
||||
|
||||
class RandomCropVideo(object):
|
||||
def __init__(self, size):
|
||||
if isinstance(size, numbers.Number):
|
||||
self.size = (int(size), int(size))
|
||||
else:
|
||||
self.size = size
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W).
|
||||
Returns:
|
||||
torch.tensor: randomly cropped/resized video clip.
|
||||
size is (C, T, OH, OW)
|
||||
"""
|
||||
i, j, h, w = self.get_params(clip, self.size)
|
||||
return F.crop(clip, i, j, h, w)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(size={0})'.format(self.size)
|
||||
|
||||
@staticmethod
|
||||
def get_params(clip, output_size):
|
||||
"""Get parameters for ``crop`` for a random crop.
|
||||
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W).
|
||||
output_size (tuple): Expected output size of the crop.
|
||||
|
||||
Returns:
|
||||
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
|
||||
"""
|
||||
w, h = clip.shape[3], clip.shape[2]
|
||||
th, tw = output_size
|
||||
if w == tw and h == th:
|
||||
return 0, 0, h, w
|
||||
|
||||
i = random.randint(0, h - th)
|
||||
j = random.randint(0, w - tw)
|
||||
return i, j, th, tw
|
||||
|
||||
|
||||
class RandomResizedCropVideo(object):
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
scale=(0.08, 1.0),
|
||||
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
||||
interpolation_mode="bilinear",
|
||||
):
|
||||
if isinstance(size, tuple):
|
||||
assert len(size) == 2, "size should be tuple (height, width)"
|
||||
self.size = size
|
||||
else:
|
||||
self.size = (size, size)
|
||||
|
||||
self.interpolation_mode = interpolation_mode
|
||||
self.scale = scale
|
||||
self.ratio = ratio
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W).
|
||||
Returns:
|
||||
torch.tensor: randomly cropped/resized video clip.
|
||||
size is (C, T, H, W)
|
||||
"""
|
||||
i, j, h, w = self.get_params(clip, self.scale, self.ratio)
|
||||
return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + \
|
||||
'(size={0}, interpolation_mode={1}, scale={2}, ratio={3})'.format(
|
||||
self.size, self.interpolation_mode, self.scale, self.ratio
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_params(clip, scale, ratio):
|
||||
"""Get parameters for ``crop`` for a random sized crop.
|
||||
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
|
||||
scale (tuple): range of size of the origin size cropped
|
||||
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
|
||||
|
||||
Returns:
|
||||
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
|
||||
sized crop.
|
||||
"""
|
||||
_w, _h = clip.shape[3], clip.shape[2]
|
||||
area = _w * _h
|
||||
|
||||
for attempt in range(10):
|
||||
target_area = random.uniform(*scale) * area
|
||||
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
|
||||
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
||||
|
||||
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
if w <= _w and h <= _h:
|
||||
i = random.randint(0, _h - h)
|
||||
j = random.randint(0, _w - w)
|
||||
return i, j, h, w
|
||||
|
||||
# Fallback to central crop
|
||||
in_ratio = _w / _h
|
||||
if in_ratio < min(ratio):
|
||||
w = _w
|
||||
h = int(round(w / min(ratio)))
|
||||
elif in_ratio > max(ratio):
|
||||
h = _h
|
||||
w = int(round(h * max(ratio)))
|
||||
else: # whole image
|
||||
w = _w
|
||||
h = _h
|
||||
i = (_h - h) // 2
|
||||
j = (_w - w) // 2
|
||||
return i, j, h, w
|
||||
|
||||
|
||||
class CenterCropVideo(object):
|
||||
def __init__(self, size):
|
||||
if isinstance(size, numbers.Number):
|
||||
self.size = (int(size), int(size))
|
||||
else:
|
||||
self.size = size
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
|
||||
Returns:
|
||||
torch.tensor: central cropping of video clip. Size is
|
||||
(C, T, size, size)
|
||||
"""
|
||||
return F.center_crop(clip, self.size)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(size={0})'.format(self.size)
|
||||
|
||||
|
||||
class NormalizeVideo(object):
|
||||
"""
|
||||
Normalize the video clip by mean subtraction and division by standard deviation
|
||||
Args:
|
||||
mean (3-tuple): pixel RGB mean
|
||||
std (3-tuple): pixel RGB standard deviation
|
||||
inplace (boolean): whether do in-place normalization
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std, inplace=False):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.inplace = inplace
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
|
||||
"""
|
||||
return F.normalize(clip, self.mean, self.std, self.inplace)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(mean={0}, std={1}, inplace={2})'.format(
|
||||
self.mean, self.std, self.inplace)
|
||||
|
||||
|
||||
class ToTensorVideo(object):
|
||||
"""
|
||||
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
||||
permute the dimenions of clip tensor
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
|
||||
Return:
|
||||
clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
|
||||
"""
|
||||
return F.to_tensor(clip)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
|
||||
class RandomHorizontalFlipVideo(object):
|
||||
"""
|
||||
Flip the video clip along the horizonal direction with a given probability
|
||||
Args:
|
||||
p (float): probability of the clip being flipped. Default value is 0.5
|
||||
"""
|
||||
|
||||
def __init__(self, p=0.5):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Size is (C, T, H, W)
|
||||
Return:
|
||||
clip (torch.tensor): Size is (C, T, H, W)
|
||||
"""
|
||||
if random.random() < self.p:
|
||||
clip = F.hflip(clip)
|
||||
return clip
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + "(p={0})".format(self.p)
|
Загрузка…
Ссылка в новой задаче