This commit is contained in:
Bin Xiao 2021-05-25 13:44:52 -07:00
Parent f42d58b109
Commit 56984edeed
35 changed files: 14270 additions and 33 deletions

130
README.md
View File

@ -1,15 +1,129 @@
# Introduction
This is an official implementation of [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808). We present a new architecture, named Convolutional vision Transformers (CvT), that improves Vision Transformers (ViT) in performance and efficiency by introducing convolutions into ViT to yield the best of both designs. This is accomplished through two primary modifications: a hierarchy of Transformers containing a new convolutional token embedding, and a convolutional Transformer block leveraging a convolutional projection. These changes introduce desirable properties of convolutional neural networks (CNNs) to the ViT architecture (e.g. shift, scale, and distortion invariance) while maintaining the merits of Transformers (e.g. dynamic attention, global context, and better generalization). We validate CvT by conducting extensive experiments, showing that this approach achieves state-of-the-art performance over other Vision Transformers and ResNets on ImageNet-1k, with fewer parameters and lower FLOPs. In addition, performance gains are maintained when pre-trained on larger datasets (e.g. ImageNet-22k) and fine-tuned to downstream tasks. Pre-trained on ImageNet-22k, our CvT-W24 obtains a top-1 accuracy of 87.7% on the ImageNet-1k val set. Finally, our results show that the positional encoding, a crucial component in existing Vision Transformers, can be safely removed in our model, simplifying the design for higher resolution vision tasks.
![](figures/pipeline.svg)
# Main results
## Models pre-trained on ImageNet-1k
| Model | Resolution | Param | GFLOPs | Top-1 |
|--------|------------|-------|--------|-------|
| CvT-13 | 224x224 | 20M | 4.5 | 81.6 |
| CvT-21 | 224x224 | 32M | 7.1 | 82.5 |
| CvT-13 | 384x384 | 20M | 16.3 | 83.0 |
| CvT-21 | 384x384 | 32M | 24.9 | 83.3 |
## Models pre-trained on ImageNet-22k
| Model | Resolution | Param | GFLOPs | Top-1 |
|---------|------------|-------|--------|-------|
| CvT-13 | 384x384 | 20M | 16.3 | 83.3 |
| CvT-21  | 384x384 | 32M | 24.9 | 84.9 |
| CvT-W24 | 384x384 | 277M | 193.2 | 87.6 |
You can download all the models from our [model zoo](https://1drv.ms/u/s!AhIXJn_J-blW9RzF3rMW7SsLHa8h?e=blQ0Al).
# Quick start
## Installation
Assuming that you have installed PyTorch and TorchVision (if not, please follow the [official instructions](https://pytorch.org/) to install them first), install the dependencies with:
``` sh
python -m pip install -r requirements.txt --user -q
```
The code is developed and tested with PyTorch 1.7.1; other PyTorch versions are not fully tested.
## Data preparation
Please prepare the data as follows:
``` sh
|-DATASET
|-imagenet
|-train
| |-class1
| | |-img1.jpg
| | |-img2.jpg
| | |-...
| |-class2
| | |-img3.jpg
| | |-...
| |-class3
| | |-img4.jpg
| | |-...
| |-...
|-val
|-class1
| |-img5.jpg
| |-...
|-class2
| |-img6.jpg
| |-...
|-class3
| |-img7.jpg
| |-...
|-...
```
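If you want to sanity-check the layout before training, torchvision's `ImageFolder` (the same class the data loader in `lib/dataset/build.py` uses) should be able to enumerate it. A minimal sketch, assuming the tree above sits at `DATASET/imagenet`:
``` python
# Hypothetical layout check (not part of this repo).
import torchvision.datasets as datasets

train_set = datasets.ImageFolder('DATASET/imagenet/train')
val_set = datasets.ImageFolder('DATASET/imagenet/val')
print(f'train: {len(train_set)} images / {len(train_set.classes)} classes')
print(f'val:   {len(val_set)} images / {len(val_set.classes)} classes')
```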
## Run
Each experiment is defined by a YAML config file saved under the `experiments` directory, which has a tree structure like this:
``` sh
experiments
|-{DATASET_A}
| |-{ARCH_A}
| |-{ARCH_B}
|-{DATASET_B}
| |-{ARCH_A}
| |-{ARCH_B}
|-{DATASET_C}
| |-{ARCH_A}
| |-{ARCH_B}
|-...
```
We provide a `run.sh` script for running jobs on a local machine.
``` sh
Usage: run.sh [run_options]
Options:
-g|--gpus <1> - number of gpus to be used
-t|--job-type <aml> - job type (train|test)
-p|--port <9000> - master port
 -i|--install-deps      - whether to install dependencies (default: False)
```
### Training on local machine
``` sh
bash run.sh -g 8 -t train --cfg experiments/imagenet/cvt/cvt-13-224x224.yaml
```
You can also modify config parameters from the command line. For example, to change the learning rate to 0.1, run:
``` sh
bash run.sh -g 8 -t train --cfg experiments/imagenet/cvt/cvt-13-224x224.yaml TRAIN.LR 0.1
```
Notes:
- The checkpoint, model, and log files will be saved in `OUTPUT/{dataset}/{training config}` by default.
### Testing pre-trained models
``` sh
bash run.sh -t test --cfg experiments/imagenet/cvt/cvt-13-224x224.yaml TEST.MODEL_FILE ${PRETRAINED_MODEL_FILE}
```
# Citation
If you find this work or code helpful in your research, please cite:
```
@article{wu2021cvt,
title={Cvt: Introducing convolutions to vision transformers},
author={Wu, Haiping and Xiao, Bin and Codella, Noel and Liu, Mengchen and Dai, Xiyang and Yuan, Lu and Zhang, Lei},
journal={arXiv preprint arXiv:2103.15808},
year={2021}
}
```
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

View File

@ -1,25 +0,0 @@
# TODO: The maintainer of this repo has not yet edited this file
**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
- **No CSS support:** Fill out this template with information about how to file issues and get help.
- **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
- **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
# Support
## How to file issues and get help
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
issues before filing new issues to avoid duplicates. For new issues, file your bug or
feature request as a new Issue.
For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
## Microsoft Support Policy
Support for this **PROJECT or PRODUCT** is limited to the resources listed above.

View File

@ -0,0 +1,83 @@
OUTPUT_DIR: 'OUTPUT/'
WORKERS: 6
PRINT_FREQ: 500
AMP:
ENABLED: true
MODEL:
NAME: cls_cvt
SPEC:
INIT: 'trunc_norm'
NUM_STAGES: 3
PATCH_SIZE: [7, 3, 3]
PATCH_STRIDE: [4, 2, 2]
PATCH_PADDING: [2, 1, 1]
DIM_EMBED: [64, 192, 384]
NUM_HEADS: [1, 3, 6]
DEPTH: [1, 2, 10]
MLP_RATIO: [4.0, 4.0, 4.0]
ATTN_DROP_RATE: [0.0, 0.0, 0.0]
DROP_RATE: [0.0, 0.0, 0.0]
DROP_PATH_RATE: [0.0, 0.0, 0.1]
QKV_BIAS: [True, True, True]
CLS_TOKEN: [False, False, True]
POS_EMBED: [False, False, False]
QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn']
KERNEL_QKV: [3, 3, 3]
PADDING_KV: [1, 1, 1]
STRIDE_KV: [2, 2, 2]
PADDING_Q: [1, 1, 1]
STRIDE_Q: [1, 1, 1]
AUG:
MIXUP_PROB: 1.0
MIXUP: 0.8
MIXCUT: 1.0
TIMM_AUG:
USE_LOADER: true
RE_COUNT: 1
RE_MODE: pixel
RE_SPLIT: false
RE_PROB: 0.25
AUTO_AUGMENT: rand-m9-mstd0.5-inc1
HFLIP: 0.5
VFLIP: 0.0
COLOR_JITTER: 0.4
INTERPOLATION: bicubic
LOSS:
LABEL_SMOOTHING: 0.1
CUDNN:
BENCHMARK: true
DETERMINISTIC: false
ENABLED: true
DATASET:
DATASET: 'imagenet'
DATA_FORMAT: 'jpg'
ROOT: 'DATASET/imagenet/'
TEST_SET: 'val'
TRAIN_SET: 'train'
TEST:
BATCH_SIZE_PER_GPU: 32
IMAGE_SIZE: [224, 224]
MODEL_FILE: ''
INTERPOLATION: 3
TRAIN:
BATCH_SIZE_PER_GPU: 256
LR: 0.00025
IMAGE_SIZE: [224, 224]
BEGIN_EPOCH: 0
END_EPOCH: 300
LR_SCHEDULER:
METHOD: 'timm'
ARGS:
sched: 'cosine'
warmup_epochs: 5
warmup_lr: 0.000001
min_lr: 0.00001
cooldown_epochs: 10
decay_rate: 0.1
OPTIMIZER: adamW
WD: 0.05
WITHOUT_WD_LIST: ['bn', 'bias', 'ln']
SHUFFLE: true
DEBUG:
DEBUG: false

View File

@ -0,0 +1,84 @@
OUTPUT_DIR: 'OUTPUT/'
WORKERS: 6
PRINT_FREQ: 500
AMP:
ENABLED: true
MODEL:
NAME: cls_cvt
SPEC:
INIT: 'trunc_norm'
NUM_STAGES: 3
PATCH_SIZE: [7, 3, 3]
PATCH_STRIDE: [4, 2, 2]
PATCH_PADDING: [2, 1, 1]
DIM_EMBED: [64, 192, 384]
NUM_HEADS: [1, 3, 6]
DEPTH: [1, 2, 10]
MLP_RATIO: [4.0, 4.0, 4.0]
ATTN_DROP_RATE: [0.0, 0.0, 0.0]
DROP_RATE: [0.0, 0.0, 0.0]
DROP_PATH_RATE: [0.0, 0.0, 0.1]
QKV_BIAS: [True, True, True]
CLS_TOKEN: [False, False, True]
POS_EMBED: [False, False, False]
QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn']
KERNEL_QKV: [3, 3, 3]
PADDING_KV: [1, 1, 1]
STRIDE_KV: [2, 2, 2]
PADDING_Q: [1, 1, 1]
STRIDE_Q: [1, 1, 1]
AUG:
MIXUP_PROB: 1.0
MIXUP: 0.8
MIXCUT: 1.0
TIMM_AUG:
USE_LOADER: true
RE_COUNT: 1
RE_MODE: pixel
RE_SPLIT: false
RE_PROB: 0.25
AUTO_AUGMENT: rand-m9-mstd0.5-inc1
HFLIP: 0.5
VFLIP: 0.0
COLOR_JITTER: 0.4
INTERPOLATION: bicubic
LOSS:
LABEL_SMOOTHING: 0.1
CUDNN:
BENCHMARK: true
DETERMINISTIC: false
ENABLED: true
DATASET:
DATASET: 'imagenet'
DATA_FORMAT: 'jpg'
ROOT: 'DATASET/imagenet/'
TEST_SET: 'val'
TRAIN_SET: 'train'
TEST:
BATCH_SIZE_PER_GPU: 32
IMAGE_SIZE: [384, 384]
CENTER_CROP: False
MODEL_FILE: ''
INTERPOLATION: 3
TRAIN:
BATCH_SIZE_PER_GPU: 256
LR: 0.00025
IMAGE_SIZE: [384, 384]
BEGIN_EPOCH: 0
END_EPOCH: 300
LR_SCHEDULER:
METHOD: 'timm'
ARGS:
sched: 'cosine'
warmup_epochs: 5
warmup_lr: 0.000001
min_lr: 0.00001
cooldown_epochs: 10
decay_rate: 0.1
OPTIMIZER: adamW
WD: 0.05
WITHOUT_WD_LIST: ['bn', 'bias', 'ln']
SHUFFLE: true
DEBUG:
DEBUG: false

View File

@ -0,0 +1,84 @@
OUTPUT_DIR: 'OUTPUT/'
WORKERS: 6
PRINT_FREQ: 500
AMP:
ENABLED: true
MODEL:
NAME: cls_cvt
SPEC:
INIT: 'trunc_norm'
NUM_STAGES: 3
PATCH_SIZE: [7, 3, 3]
PATCH_STRIDE: [4, 2, 2]
PATCH_PADDING: [2, 1, 1]
DIM_EMBED: [64, 192, 384]
NUM_HEADS: [1, 3, 6]
DEPTH: [1, 4, 16]
MLP_RATIO: [4.0, 4.0, 4.0]
ATTN_DROP_RATE: [0.0, 0.0, 0.0]
DROP_RATE: [0.0, 0.0, 0.0]
DROP_PATH_RATE: [0.0, 0.0, 0.1]
QKV_BIAS: [True, True, True]
CLS_TOKEN: [False, False, True]
POS_EMBED: [False, False, False]
QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn']
KERNEL_QKV: [3, 3, 3]
PADDING_KV: [1, 1, 1]
STRIDE_KV: [2, 2, 2]
PADDING_Q: [1, 1, 1]
STRIDE_Q: [1, 1, 1]
AUG:
MIXUP_PROB: 1.0
MIXUP: 0.8
MIXCUT: 1.0
TIMM_AUG:
USE_LOADER: false
RE_COUNT: 1
RE_MODE: pixel
RE_SPLIT: false
RE_PROB: 0.25
AUTO_AUGMENT: rand-m9-mstd0.5-inc1
HFLIP: 0.5
VFLIP: 0.0
COLOR_JITTER: 0.4
INTERPOLATION: bicubic
LOSS:
LABEL_SMOOTHING: 0.1
CUDNN:
BENCHMARK: true
DETERMINISTIC: false
ENABLED: true
DATASET:
DATASET: 'imagenet'
DATA_FORMAT: 'jpg'
ROOT: 'DATASET/imagenet/'
TEST_SET: 'val'
TRAIN_SET: 'train'
SAMPLER: repeated_aug
TEST:
BATCH_SIZE_PER_GPU: 32
IMAGE_SIZE: [224, 224]
MODEL_FILE: ''
INTERPOLATION: 3
TRAIN:
BATCH_SIZE_PER_GPU: 128
LR: 0.000125
IMAGE_SIZE: [224, 224]
BEGIN_EPOCH: 0
END_EPOCH: 300
LR_SCHEDULER:
METHOD: 'timm'
ARGS:
sched: 'cosine'
warmup_epochs: 5
warmup_lr: 0.000001
min_lr: 0.00001
cooldown_epochs: 10
decay_rate: 0.1
OPTIMIZER: adamW
WD: 0.1
WITHOUT_WD_LIST: ['bn', 'bias', 'ln']
SHUFFLE: true
DEBUG:
DEBUG: false

View File

@ -0,0 +1,84 @@
OUTPUT_DIR: 'OUTPUT/'
WORKERS: 6
PRINT_FREQ: 500
AMP:
ENABLED: true
MODEL:
NAME: cls_cvt
SPEC:
INIT: 'trunc_norm'
NUM_STAGES: 3
PATCH_SIZE: [7, 3, 3]
PATCH_STRIDE: [4, 2, 2]
PATCH_PADDING: [2, 1, 1]
DIM_EMBED: [64, 192, 384]
NUM_HEADS: [1, 3, 6]
DEPTH: [1, 4, 16]
MLP_RATIO: [4.0, 4.0, 4.0]
ATTN_DROP_RATE: [0.0, 0.0, 0.0]
DROP_RATE: [0.0, 0.0, 0.0]
DROP_PATH_RATE: [0.0, 0.0, 0.1]
QKV_BIAS: [True, True, True]
CLS_TOKEN: [False, False, True]
POS_EMBED: [False, False, False]
QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn']
KERNEL_QKV: [3, 3, 3]
PADDING_KV: [1, 1, 1]
STRIDE_KV: [2, 2, 2]
PADDING_Q: [1, 1, 1]
STRIDE_Q: [1, 1, 1]
AUG:
MIXUP_PROB: 1.0
MIXUP: 0.8
MIXCUT: 1.0
TIMM_AUG:
USE_LOADER: true
RE_COUNT: 1
RE_MODE: pixel
RE_SPLIT: false
RE_PROB: 0.25
AUTO_AUGMENT: rand-m9-mstd0.5-inc1
HFLIP: 0.5
VFLIP: 0.0
COLOR_JITTER: 0.4
INTERPOLATION: bicubic
LOSS:
LABEL_SMOOTHING: 0.1
CUDNN:
BENCHMARK: true
DETERMINISTIC: false
ENABLED: true
DATASET:
DATASET: 'imagenet'
DATA_FORMAT: 'jpg'
ROOT: 'DATASET/imagenet/'
TEST_SET: 'val'
TRAIN_SET: 'train'
TEST:
BATCH_SIZE_PER_GPU: 32
IMAGE_SIZE: [384, 384]
MODEL_FILE: ''
INTERPOLATION: 3
CENTER_CROP: False
TRAIN:
BATCH_SIZE_PER_GPU: 128
LR: 0.000125
IMAGE_SIZE: [384, 384]
BEGIN_EPOCH: 0
END_EPOCH: 300
LR_SCHEDULER:
METHOD: 'timm'
ARGS:
sched: 'cosine'
warmup_epochs: 5
warmup_lr: 0.000001
min_lr: 0.00001
cooldown_epochs: 10
decay_rate: 0.1
OPTIMIZER: adamW
WD: 0.1
WITHOUT_WD_LIST: ['bn', 'bias', 'ln']
SHUFFLE: true
DEBUG:
DEBUG: false

View File

@ -0,0 +1,84 @@
OUTPUT_DIR: 'OUTPUT/'
WORKERS: 6
PRINT_FREQ: 500
AMP:
ENABLED: true
MODEL:
NAME: cls_cvt
SPEC:
INIT: 'trunc_norm'
NUM_STAGES: 3
PATCH_SIZE: [7, 3, 3]
PATCH_STRIDE: [4, 2, 2]
PATCH_PADDING: [2, 1, 1]
DIM_EMBED: [192, 768, 1024]
NUM_HEADS: [3, 12, 16]
DEPTH: [2, 2, 20]
MLP_RATIO: [4.0, 4.0, 4.0]
ATTN_DROP_RATE: [0.0, 0.0, 0.0]
DROP_RATE: [0.0, 0.0, 0.0]
DROP_PATH_RATE: [0.0, 0.0, 0.3]
QKV_BIAS: [True, True, True]
CLS_TOKEN: [False, False, True]
POS_EMBED: [False, False, False]
QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn']
KERNEL_QKV: [3, 3, 3]
PADDING_KV: [1, 1, 1]
STRIDE_KV: [2, 2, 2]
PADDING_Q: [1, 1, 1]
STRIDE_Q: [1, 1, 1]
AUG:
MIXUP_PROB: 1.0
MIXUP: 0.8
MIXCUT: 1.0
TIMM_AUG:
USE_LOADER: true
RE_COUNT: 1
RE_MODE: pixel
RE_SPLIT: false
RE_PROB: 0.25
AUTO_AUGMENT: rand-m9-mstd0.5-inc1
HFLIP: 0.5
VFLIP: 0.0
COLOR_JITTER: 0.4
INTERPOLATION: bicubic
LOSS:
LABEL_SMOOTHING: 0.1
CUDNN:
BENCHMARK: true
DETERMINISTIC: false
ENABLED: true
DATASET:
DATASET: 'imagenet'
DATA_FORMAT: 'jpg'
ROOT: 'DATASET/imagenet/'
TEST_SET: 'val'
TRAIN_SET: 'train'
TEST:
BATCH_SIZE_PER_GPU: 32
IMAGE_SIZE: [384, 384]
MODEL_FILE: ''
INTERPOLATION: 3
CENTER_CROP: False
TRAIN:
BATCH_SIZE_PER_GPU: 128
LR: 0.000125
IMAGE_SIZE: [384, 384]
BEGIN_EPOCH: 0
END_EPOCH: 300
LR_SCHEDULER:
METHOD: 'timm'
ARGS:
sched: 'cosine'
warmup_epochs: 5
warmup_lr: 0.000001
min_lr: 0.00001
cooldown_epochs: 10
decay_rate: 0.1
OPTIMIZER: adamW
WD: 0.1
WITHOUT_WD_LIST: ['bn', 'bias', 'ln']
SHUFFLE: true
DEBUG:
DEBUG: false

11175
figures/pipeline.svg Normal file

File diff not shown because of its large size.

After

Width:  |  Height:  |  Size: 923 KiB

4
lib/config/__init__.py Normal file
View File

@ -0,0 +1,4 @@
from .default import _C as config
from .default import update_config
from .default import _update_config_from_file
from .default import save_config

203
lib/config/default.py Normal file
View File

@ -0,0 +1,203 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path as op
import yaml
from yacs.config import CfgNode as CN
from lib.utils.comm import comm
_C = CN()
_C.BASE = ['']
_C.NAME = ''
_C.DATA_DIR = ''
_C.DIST_BACKEND = 'nccl'
_C.GPUS = (0,)
# _C.LOG_DIR = ''
_C.MULTIPROCESSING_DISTRIBUTED = True
_C.OUTPUT_DIR = ''
_C.PIN_MEMORY = True
_C.PRINT_FREQ = 20
_C.RANK = 0
_C.VERBOSE = True
_C.WORKERS = 4
_C.MODEL_SUMMARY = False
_C.AMP = CN()
_C.AMP.ENABLED = False
_C.AMP.MEMORY_FORMAT = 'nchw'
# Cudnn related params
_C.CUDNN = CN()
_C.CUDNN.BENCHMARK = True
_C.CUDNN.DETERMINISTIC = False
_C.CUDNN.ENABLED = True
# common params for NETWORK
_C.MODEL = CN()
_C.MODEL.NAME = 'cls_hrnet'
_C.MODEL.INIT_WEIGHTS = True
_C.MODEL.PRETRAINED = ''
_C.MODEL.PRETRAINED_LAYERS = ['*']
_C.MODEL.NUM_CLASSES = 1000
_C.MODEL.SPEC = CN(new_allowed=True)
_C.LOSS = CN(new_allowed=True)
_C.LOSS.LABEL_SMOOTHING = 0.0
_C.LOSS.LOSS = 'softmax'
# DATASET related params
_C.DATASET = CN()
_C.DATASET.ROOT = ''
_C.DATASET.DATASET = 'imagenet'
_C.DATASET.TRAIN_SET = 'train'
_C.DATASET.TEST_SET = 'val'
_C.DATASET.DATA_FORMAT = 'jpg'
_C.DATASET.LABELMAP = ''
_C.DATASET.TRAIN_TSV_LIST = []
_C.DATASET.TEST_TSV_LIST = []
_C.DATASET.SAMPLER = 'default'
_C.DATASET.TARGET_SIZE = -1
# training data augmentation
_C.INPUT = CN()
_C.INPUT.MEAN = [0.485, 0.456, 0.406]
_C.INPUT.STD = [0.229, 0.224, 0.225]
# data augmentation
_C.AUG = CN()
_C.AUG.SCALE = (0.08, 1.0)
_C.AUG.RATIO = (3.0/4.0, 4.0/3.0)
_C.AUG.COLOR_JITTER = [0.4, 0.4, 0.4, 0.1, 0.0]
_C.AUG.GRAY_SCALE = 0.0
_C.AUG.GAUSSIAN_BLUR = 0.0
_C.AUG.DROPBLOCK_LAYERS = [3, 4]
_C.AUG.DROPBLOCK_KEEP_PROB = 1.0
_C.AUG.DROPBLOCK_BLOCK_SIZE = 7
_C.AUG.MIXUP_PROB = 0.0
_C.AUG.MIXUP = 0.0
_C.AUG.MIXCUT = 0.0
_C.AUG.MIXCUT_MINMAX = []
_C.AUG.MIXUP_SWITCH_PROB = 0.5
_C.AUG.MIXUP_MODE = 'batch'
_C.AUG.MIXCUT_AND_MIXUP = False
_C.AUG.INTERPOLATION = 2
_C.AUG.TIMM_AUG = CN(new_allowed=True)
_C.AUG.TIMM_AUG.USE_LOADER = False
_C.AUG.TIMM_AUG.USE_TRANSFORM = False
# train
_C.TRAIN = CN()
_C.TRAIN.AUTO_RESUME = True
_C.TRAIN.CHECKPOINT = ''
_C.TRAIN.LR_SCHEDULER = CN(new_allowed=True)
_C.TRAIN.SCALE_LR = True
_C.TRAIN.LR = 0.001
_C.TRAIN.OPTIMIZER = 'sgd'
_C.TRAIN.OPTIMIZER_ARGS = CN(new_allowed=True)
_C.TRAIN.MOMENTUM = 0.9
_C.TRAIN.WD = 0.0001
_C.TRAIN.WITHOUT_WD_LIST = []
_C.TRAIN.NESTEROV = True
# for adam
_C.TRAIN.GAMMA1 = 0.99
_C.TRAIN.GAMMA2 = 0.0
_C.TRAIN.BEGIN_EPOCH = 0
_C.TRAIN.END_EPOCH = 100
_C.TRAIN.IMAGE_SIZE = [224, 224] # width * height, ex: 192 * 256
_C.TRAIN.BATCH_SIZE_PER_GPU = 32
_C.TRAIN.SHUFFLE = True
_C.TRAIN.EVAL_BEGIN_EPOCH = 0
_C.TRAIN.DETECT_ANOMALY = False
_C.TRAIN.CLIP_GRAD_NORM = 0.0
_C.TRAIN.SAVE_ALL_MODELS = False
# testing
_C.TEST = CN()
# size of images for each device
_C.TEST.BATCH_SIZE_PER_GPU = 32
_C.TEST.CENTER_CROP = True
_C.TEST.IMAGE_SIZE = [224, 224] # width * height, ex: 192 * 256
_C.TEST.INTERPOLATION = 2
_C.TEST.MODEL_FILE = ''
_C.TEST.REAL_LABELS = False
_C.TEST.VALID_LABELS = ''
_C.FINETUNE = CN()
_C.FINETUNE.FINETUNE = False
_C.FINETUNE.USE_TRAIN_AUG = False
_C.FINETUNE.BASE_LR = 0.003
_C.FINETUNE.BATCH_SIZE = 512
_C.FINETUNE.EVAL_EVERY = 3000
_C.FINETUNE.TRAIN_MODE = True
# _C.FINETUNE.MODEL_FILE = ''
_C.FINETUNE.FROZEN_LAYERS = []
_C.FINETUNE.LR_SCHEDULER = CN(new_allowed=True)
_C.FINETUNE.LR_SCHEDULER.DECAY_TYPE = 'step'
# debug
_C.DEBUG = CN()
_C.DEBUG.DEBUG = False
def _update_config_from_file(config, cfg_file):
config.defrost()
with open(cfg_file, 'r') as f:
yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
for cfg in yaml_cfg.setdefault('BASE', ['']):
if cfg:
_update_config_from_file(
config, op.join(op.dirname(cfg_file), cfg)
)
print('=> merge config from {}'.format(cfg_file))
config.merge_from_file(cfg_file)
config.freeze()
def update_config(config, args):
_update_config_from_file(config, args.cfg)
config.defrost()
config.merge_from_list(args.opts)
if config.TRAIN.SCALE_LR:
config.TRAIN.LR *= comm.world_size
file_name, _ = op.splitext(op.basename(args.cfg))
config.NAME = file_name + config.NAME
config.RANK = comm.rank
if 'timm' == config.TRAIN.LR_SCHEDULER.METHOD:
config.TRAIN.LR_SCHEDULER.ARGS.epochs = config.TRAIN.END_EPOCH
if 'timm' == config.TRAIN.OPTIMIZER:
config.TRAIN.OPTIMIZER_ARGS.lr = config.TRAIN.LR
aug = config.AUG
if aug.MIXUP > 0.0 or aug.MIXCUT > 0.0 or aug.MIXCUT_MINMAX:
aug.MIXUP_PROB = 1.0
config.freeze()
def save_config(cfg, path):
if comm.is_main_process():
with open(path, 'w') as f:
f.write(cfg.dump())
if __name__ == '__main__':
import sys
with open(sys.argv[1], 'w') as f:
print(_C, file=f)
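For reference, a minimal sketch of how this config module is consumed (not part of this commit): `update_config` expects an argparse-style object with `.cfg` and `.opts`, and the repo root is assumed to be on `PYTHONPATH`.
``` python
from types import SimpleNamespace
from lib.config import config, update_config, save_config

# Hypothetical driver: merge an experiment yaml plus command-line overrides.
args = SimpleNamespace(
    cfg='experiments/imagenet/cvt/cvt-13-224x224.yaml',
    opts=['TRAIN.LR', '0.1'],  # same KEY VALUE override syntax as run.sh
)
update_config(config, args)
print(config.MODEL.NAME)   # cls_cvt
save_config(config, 'merged_config.yaml')
```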

25
lib/core/evaluate.py Normal file
View File

@ -0,0 +1,25 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if isinstance(output, list):
output = output[-1]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size).item())
return res
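A toy check of `accuracy` on a hand-made batch (illustrative only; assumes the repo root is on `PYTHONPATH`):
``` python
import torch
from lib.core.evaluate import accuracy

# Two samples, three classes: the first is right at top-1,
# the second only within the top-2 predictions.
logits = torch.tensor([[0.1, 0.9, 0.0],
                       [0.8, 0.1, 0.3]])
target = torch.tensor([1, 2])
top1, top2 = accuracy(logits, target, topk=(1, 2))
print(top1, top2)   # 50.0 100.0
```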

223
lib/core/function.py Normal file
View File

@ -0,0 +1,223 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import time
import torch
from timm.data import Mixup
from torch.cuda.amp import autocast
from core.evaluate import accuracy
from utils.comm import comm
def train_one_epoch(config, train_loader, model, criterion, optimizer, epoch,
output_dir, tb_log_dir, writer_dict, scaler=None):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
logging.info('=> switch to train mode')
model.train()
aug = config.AUG
mixup_fn = Mixup(
mixup_alpha=aug.MIXUP, cutmix_alpha=aug.MIXCUT,
cutmix_minmax=aug.MIXCUT_MINMAX if aug.MIXCUT_MINMAX else None,
prob=aug.MIXUP_PROB, switch_prob=aug.MIXUP_SWITCH_PROB,
mode=aug.MIXUP_MODE, label_smoothing=config.LOSS.LABEL_SMOOTHING,
num_classes=config.MODEL.NUM_CLASSES
) if aug.MIXUP_PROB > 0.0 else None
end = time.time()
for i, (x, y) in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)
# compute output
x = x.cuda(non_blocking=True)
y = y.cuda(non_blocking=True)
if mixup_fn:
x, y = mixup_fn(x, y)
with autocast(enabled=config.AMP.ENABLED):
            if config.AMP.ENABLED and config.AMP.MEMORY_FORMAT == 'nhwc':
x = x.contiguous(memory_format=torch.channels_last)
y = y.contiguous(memory_format=torch.channels_last)
outputs = model(x)
loss = criterion(outputs, y)
# compute gradient and do update step
optimizer.zero_grad()
is_second_order = hasattr(optimizer, 'is_second_order') \
and optimizer.is_second_order
scaler.scale(loss).backward(create_graph=is_second_order)
if config.TRAIN.CLIP_GRAD_NORM > 0.0:
# Unscales the gradients of optimizer's assigned params in-place
scaler.unscale_(optimizer)
# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
torch.nn.utils.clip_grad_norm_(
model.parameters(), config.TRAIN.CLIP_GRAD_NORM
)
scaler.step(optimizer)
scaler.update()
# measure accuracy and record loss
losses.update(loss.item(), x.size(0))
if mixup_fn:
y = torch.argmax(y, dim=1)
prec1, prec5 = accuracy(outputs, y, (1, 5))
top1.update(prec1, x.size(0))
top5.update(prec5, x.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % config.PRINT_FREQ == 0:
msg = '=> Epoch[{0}][{1}/{2}]: ' \
'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
'Speed {speed:.1f} samples/s\t' \
'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
'Loss {loss.val:.5f} ({loss.avg:.5f})\t' \
'Accuracy@1 {top1.val:.3f} ({top1.avg:.3f})\t' \
'Accuracy@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
epoch, i, len(train_loader),
batch_time=batch_time,
speed=x.size(0)/batch_time.val,
data_time=data_time, loss=losses, top1=top1, top5=top5)
logging.info(msg)
torch.cuda.synchronize()
if writer_dict and comm.is_main_process():
writer = writer_dict['writer']
global_steps = writer_dict['train_global_steps']
writer.add_scalar('train_loss', losses.avg, global_steps)
writer.add_scalar('train_top1', top1.avg, global_steps)
writer_dict['train_global_steps'] = global_steps + 1
@torch.no_grad()
def test(config, val_loader, model, criterion, output_dir, tb_log_dir,
writer_dict=None, distributed=False, real_labels=None,
valid_labels=None):
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
logging.info('=> switch to eval mode')
model.eval()
end = time.time()
for i, (x, y) in enumerate(val_loader):
# compute output
x = x.cuda(non_blocking=True)
y = y.cuda(non_blocking=True)
outputs = model(x)
if valid_labels:
outputs = outputs[:, valid_labels]
loss = criterion(outputs, y)
if real_labels and not distributed:
real_labels.add_result(outputs)
# measure accuracy and record loss
losses.update(loss.item(), x.size(0))
prec1, prec5 = accuracy(outputs, y, (1, 5))
top1.update(prec1, x.size(0))
top5.update(prec5, x.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
logging.info('=> synchronize...')
comm.synchronize()
top1_acc, top5_acc, loss_avg = map(
_meter_reduce if distributed else lambda x: x.avg,
[top1, top5, losses]
)
if real_labels and not distributed:
real_top1 = real_labels.get_accuracy(k=1)
real_top5 = real_labels.get_accuracy(k=5)
msg = '=> TEST using Reassessed labels:\t' \
'Error@1 {error1:.3f}%\t' \
'Error@5 {error5:.3f}%\t' \
'Accuracy@1 {top1:.3f}%\t' \
'Accuracy@5 {top5:.3f}%\t'.format(
top1=real_top1,
top5=real_top5,
error1=100-real_top1,
error5=100-real_top5
)
logging.info(msg)
if comm.is_main_process():
msg = '=> TEST:\t' \
'Loss {loss_avg:.4f}\t' \
'Error@1 {error1:.3f}%\t' \
'Error@5 {error5:.3f}%\t' \
'Accuracy@1 {top1:.3f}%\t' \
'Accuracy@5 {top5:.3f}%\t'.format(
loss_avg=loss_avg, top1=top1_acc,
top5=top5_acc, error1=100-top1_acc,
error5=100-top5_acc
)
logging.info(msg)
if writer_dict and comm.is_main_process():
writer = writer_dict['writer']
global_steps = writer_dict['valid_global_steps']
writer.add_scalar('valid_loss', loss_avg, global_steps)
writer.add_scalar('valid_top1', top1_acc, global_steps)
writer_dict['valid_global_steps'] = global_steps + 1
logging.info('=> switch to train mode')
model.train()
return top1_acc
def _meter_reduce(meter):
rank = comm.local_rank
meter_sum = torch.FloatTensor([meter.sum]).cuda(rank)
meter_count = torch.FloatTensor([meter.count]).cuda(rank)
torch.distributed.reduce(meter_sum, 0)
torch.distributed.reduce(meter_count, 0)
meter_avg = meter_sum / meter_count
return meter_avg.item()
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
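`AverageMeter` keeps a running, sample-weighted mean. A standalone sketch (the `sys.path` line mirrors how the training tools make the `lib/` modules importable as `core.*`):
``` python
import sys
sys.path.append('lib')  # so that `core.*` and `utils.*` imports resolve
from core.function import AverageMeter

meter = AverageMeter()
for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    meter.update(batch_loss, n=batch_size)
print(meter.val, round(meter.avg, 3))   # 0.5 0.74 (weighted by batch size)
```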

49
lib/core/loss.py Normal file
View File

@ -0,0 +1,49 @@
import torch as th
import torch.nn as nn
import torch.nn.functional as F
def linear_combination(x, y, epsilon):
return epsilon*x + (1-epsilon)*y
def reduce_loss(loss, reduction='mean'):
return loss.mean() if reduction == 'mean' \
else loss.sum() if reduction == 'sum' else loss
class LabelSmoothingCrossEntropy(nn.Module):
def __init__(self, epsilon=0.1, reduction='mean'):
super().__init__()
self.epsilon = epsilon
self.reduction = reduction
def forward(self, preds, target):
n = preds.size()[-1]
log_preds = F.log_softmax(preds, dim=-1)
loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
nll = F.nll_loss(log_preds, target, reduction=self.reduction)
return linear_combination(loss/n, nll, self.epsilon)
class SoftTargetCrossEntropy(nn.Module):
def __init__(self):
super(SoftTargetCrossEntropy, self).__init__()
def forward(self, x, target):
loss = th.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
return loss.mean()
def build_criterion(config, train=True):
if config.AUG.MIXUP_PROB > 0.0 and config.LOSS.LOSS == 'softmax':
criterion = SoftTargetCrossEntropy() \
if train else nn.CrossEntropyLoss()
elif config.LOSS.LABEL_SMOOTHING > 0.0 and config.LOSS.LOSS == 'softmax':
criterion = LabelSmoothingCrossEntropy(config.LOSS.LABEL_SMOOTHING)
elif config.LOSS.LOSS == 'softmax':
criterion = nn.CrossEntropyLoss()
else:
        raise ValueError('Unknown loss {}'.format(config.LOSS.LOSS))
return criterion
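One property worth noting: with `epsilon=0`, `LabelSmoothingCrossEntropy` collapses to ordinary cross entropy. A small check (illustrative; repo root on `PYTHONPATH`):
``` python
import torch
import torch.nn as nn
from lib.core.loss import LabelSmoothingCrossEntropy

logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
smoothed = LabelSmoothingCrossEntropy(epsilon=0.0)(logits, target)
plain = nn.CrossEntropyLoss()(logits, target)
print(torch.allclose(smoothed, plain))   # True
```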

2
lib/dataset/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from .build import build_dataloader
from .imagenet.real_labels import RealLabelsImagenet

115
lib/dataset/build.py Normal file
View File

@ -0,0 +1,115 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import os
from timm.data import create_loader
import torch
import torch.utils.data
import torchvision.datasets as datasets
from .transformas import build_transforms
from .samplers import RASampler
def build_dataset(cfg, is_train):
dataset = None
if 'imagenet' in cfg.DATASET.DATASET:
dataset = _build_imagenet_dataset(cfg, is_train)
else:
        raise ValueError('Unknown dataset: {}'.format(cfg.DATASET.DATASET))
return dataset
def _build_image_folder_dataset(cfg, is_train):
transforms = build_transforms(cfg, is_train)
dataset_name = cfg.DATASET.TRAIN_SET if is_train else cfg.DATASET.TEST_SET
dataset = datasets.ImageFolder(
os.path.join(cfg.DATASET.ROOT, dataset_name), transforms
)
logging.info(
'=> load samples: {}, is_train: {}'
.format(len(dataset), is_train)
)
return dataset
def _build_imagenet_dataset(cfg, is_train):
transforms = build_transforms(cfg, is_train)
dataset_name = cfg.DATASET.TRAIN_SET if is_train else cfg.DATASET.TEST_SET
dataset = datasets.ImageFolder(
os.path.join(cfg.DATASET.ROOT, dataset_name), transforms
)
return dataset
def build_dataloader(cfg, is_train=True, distributed=False):
if is_train:
batch_size_per_gpu = cfg.TRAIN.BATCH_SIZE_PER_GPU
shuffle = True
else:
batch_size_per_gpu = cfg.TEST.BATCH_SIZE_PER_GPU
shuffle = False
dataset = build_dataset(cfg, is_train)
if distributed:
if is_train and cfg.DATASET.SAMPLER == 'repeated_aug':
logging.info('=> use repeated aug sampler')
sampler = RASampler(dataset, shuffle=shuffle)
else:
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, shuffle=shuffle
)
shuffle = False
else:
sampler = None
if cfg.AUG.TIMM_AUG.USE_LOADER and is_train:
logging.info('=> use timm loader for training')
timm_cfg = cfg.AUG.TIMM_AUG
data_loader = create_loader(
dataset,
input_size=cfg.TRAIN.IMAGE_SIZE[0],
batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU,
is_training=True,
use_prefetcher=True,
no_aug=False,
re_prob=timm_cfg.RE_PROB,
re_mode=timm_cfg.RE_MODE,
re_count=timm_cfg.RE_COUNT,
re_split=timm_cfg.RE_SPLIT,
scale=cfg.AUG.SCALE,
ratio=cfg.AUG.RATIO,
hflip=timm_cfg.HFLIP,
vflip=timm_cfg.VFLIP,
color_jitter=timm_cfg.COLOR_JITTER,
auto_augment=timm_cfg.AUTO_AUGMENT,
num_aug_splits=0,
interpolation=timm_cfg.INTERPOLATION,
mean=cfg.INPUT.MEAN,
std=cfg.INPUT.STD,
num_workers=cfg.WORKERS,
distributed=distributed,
collate_fn=None,
pin_memory=cfg.PIN_MEMORY,
use_multi_epochs_loader=True
)
else:
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size_per_gpu,
shuffle=shuffle,
num_workers=cfg.WORKERS,
pin_memory=cfg.PIN_MEMORY,
sampler=sampler,
drop_last=True if is_train else False,
)
return data_loader
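A hypothetical driver for the loader (not part of this commit; assumes the dataset is prepared as in the README and the repo root is on `PYTHONPATH`):
``` python
from types import SimpleNamespace
from lib.config import config, update_config
from lib.dataset import build_dataloader

args = SimpleNamespace(cfg='experiments/imagenet/cvt/cvt-13-224x224.yaml', opts=[])
update_config(config, args)
# Validation path: plain torch DataLoader (the timm loader is train-only here).
val_loader = build_dataloader(config, is_train=False, distributed=False)
images, labels = next(iter(val_loader))
print(images.shape)   # torch.Size([32, 3, 224, 224]) with the default test size
```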

View File

@ -0,0 +1,45 @@
""" Real labels evaluator for ImageNet
Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159
Based on Numpy example at https://github.com/google-research/reassessed-imagenet
Hacked together by / Copyright 2020 Ross Wightman
"""
import os
import json
import numpy as np
class RealLabelsImagenet:
def __init__(self, filenames, real_json='real.json', topk=(1, 5)):
with open(real_json) as real_labels:
real_labels = json.load(real_labels)
real_labels = {
f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels
for i, labels in enumerate(real_labels)
}
self.real_labels = real_labels
self.filenames = filenames
assert len(self.filenames) == len(self.real_labels)
self.topk = topk
self.is_correct = {k: [] for k in topk}
self.sample_idx = 0
def add_result(self, output):
maxk = max(self.topk)
_, pred_batch = output.topk(maxk, 1, True, True)
pred_batch = pred_batch.cpu().numpy()
for pred in pred_batch:
filename = self.filenames[self.sample_idx]
filename = os.path.basename(filename)
if self.real_labels[filename]:
for k in self.topk:
self.is_correct[k].append(
any([p in self.real_labels[filename] for p in pred[:k]]))
self.sample_idx += 1
def get_accuracy(self, k=None):
if k is None:
            return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk}
else:
return float(np.mean(self.is_correct[k])) * 100
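A toy walk-through of the bookkeeping with a two-image synthetic `real.json` (hypothetical values; the real file comes from the reassessed-imagenet repo linked above):
``` python
import json
import torch
from lib.dataset import RealLabelsImagenet  # assumes repo root on PYTHONPATH

with open('real.json', 'w') as f:           # synthetic stand-in for real.json
    json.dump([[1], [2, 3]], f)
filenames = ['ILSVRC2012_val_00000001.JPEG', 'ILSVRC2012_val_00000002.JPEG']
rl = RealLabelsImagenet(filenames, real_json='real.json', topk=(1, 5))

logits = torch.tensor([[0., 5., 1., 0., 0., 0.],    # top-1 = class 1: correct
                       [5., 0., 1., 2., 0., 0.]])   # class 3 only in top-5
rl.add_result(logits)
print(rl.get_accuracy(k=1), rl.get_accuracy(k=5))   # 50.0 100.0
```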

View File

@ -0,0 +1 @@
from .ra_sampler import RASampler

View File

@ -0,0 +1,63 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
#
import torch
import torch.distributed as dist
import math
class RASampler(torch.utils.data.Sampler):
"""Sampler that restricts data loading to a subset of the dataset for distributed,
with repeated augmentation.
    It ensures that each augmented version of a sample will be visible to a
    different process (GPU).
Heavily based on torch.utils.data.DistributedSampler
"""
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
# self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices = [ele for ele in indices for i in range(3)]
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices[:self.num_selected_samples])
def __len__(self):
return self.num_selected_samples
def set_epoch(self, epoch):
self.epoch = epoch
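Since `num_replicas` and `rank` can be passed explicitly, the repeated-augmentation indexing can be inspected without a process group. A small sketch (repo root on `PYTHONPATH`; any object with `__len__` works as the dataset):
``` python
from lib.dataset.samplers import RASampler

sampler = RASampler(list(range(512)), num_replicas=2, rank=0, shuffle=False)
indices = list(iter(sampler))
print(len(indices))    # 256 selected samples: floor(512 // 256 * 256 / 2)
print(indices[:8])     # [0, 0, 1, 2, 2, 3, 4, 4]: 3 copies split across ranks
```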

View File

@ -0,0 +1 @@
from .build import build_transforms

View File

@ -0,0 +1,124 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from timm.data import create_transform
from PIL import ImageFilter
import logging
import random
import torchvision.transforms as T
class GaussianBlur(object):
"""Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
def __init__(self, sigma=[.1, 2.]):
self.sigma = sigma
def __call__(self, x):
sigma = random.uniform(self.sigma[0], self.sigma[1])
x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
return x
def get_resolution(original_resolution):
"""Takes (H,W) and returns (precrop, crop)."""
area = original_resolution[0] * original_resolution[1]
return (160, 128) if area < 96*96 else (512, 480)
def build_transforms(cfg, is_train=True):
if cfg.AUG.TIMM_AUG.USE_TRANSFORM and is_train:
logging.info('=> use timm transform for training')
timm_cfg = cfg.AUG.TIMM_AUG
transforms = create_transform(
input_size=cfg.TRAIN.IMAGE_SIZE[0],
is_training=True,
use_prefetcher=False,
no_aug=False,
re_prob=timm_cfg.RE_PROB,
re_mode=timm_cfg.RE_MODE,
re_count=timm_cfg.RE_COUNT,
scale=cfg.AUG.SCALE,
ratio=cfg.AUG.RATIO,
hflip=timm_cfg.HFLIP,
vflip=timm_cfg.VFLIP,
color_jitter=timm_cfg.COLOR_JITTER,
auto_augment=timm_cfg.AUTO_AUGMENT,
interpolation=timm_cfg.INTERPOLATION,
mean=cfg.INPUT.MEAN,
std=cfg.INPUT.STD,
)
return transforms
# assert isinstance(cfg.DATASET.OUTPUT_SIZE, (list, tuple)), 'DATASET.OUTPUT_SIZE should be list or tuple'
normalize = T.Normalize(mean=cfg.INPUT.MEAN, std=cfg.INPUT.STD)
transforms = None
if is_train:
if cfg.FINETUNE.FINETUNE and not cfg.FINETUNE.USE_TRAIN_AUG:
# precrop, crop = get_resolution(cfg.TRAIN.IMAGE_SIZE)
crop = cfg.TRAIN.IMAGE_SIZE[0]
precrop = crop + 32
transforms = T.Compose([
T.Resize(
(precrop, precrop),
interpolation=cfg.AUG.INTERPOLATION
),
T.RandomCrop((crop, crop)),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize,
])
else:
aug = cfg.AUG
scale = aug.SCALE
ratio = aug.RATIO
ts = [
T.RandomResizedCrop(
cfg.TRAIN.IMAGE_SIZE[0], scale=scale, ratio=ratio,
interpolation=cfg.AUG.INTERPOLATION
),
T.RandomHorizontalFlip(),
]
cj = aug.COLOR_JITTER
if cj[-1] > 0.0:
ts.append(T.RandomApply([T.ColorJitter(*cj[:-1])], p=cj[-1]))
gs = aug.GRAY_SCALE
if gs > 0.0:
ts.append(T.RandomGrayscale(gs))
gb = aug.GAUSSIAN_BLUR
if gb > 0.0:
ts.append(T.RandomApply([GaussianBlur([.1, 2.])], p=gb))
ts.append(T.ToTensor())
ts.append(normalize)
transforms = T.Compose(ts)
else:
if cfg.TEST.CENTER_CROP:
transforms = T.Compose([
T.Resize(
int(cfg.TEST.IMAGE_SIZE[0] / 0.875),
interpolation=cfg.TEST.INTERPOLATION
),
T.CenterCrop(cfg.TEST.IMAGE_SIZE[0]),
T.ToTensor(),
normalize,
])
else:
transforms = T.Compose([
T.Resize(
(cfg.TEST.IMAGE_SIZE[1], cfg.TEST.IMAGE_SIZE[0]),
interpolation=cfg.TEST.INTERPOLATION
),
T.ToTensor(),
normalize,
])
return transforms
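For the eval path, a quick sketch of what the default config produces (illustrative; the module path follows the `transformas` package name used by the import in `lib/dataset/build.py`):
``` python
from types import SimpleNamespace
from PIL import Image
from lib.config import config, update_config
from lib.dataset.transformas import build_transforms

args = SimpleNamespace(cfg='experiments/imagenet/cvt/cvt-13-224x224.yaml', opts=[])
update_config(config, args)
tfm = build_transforms(config, is_train=False)
x = tfm(Image.new('RGB', (320, 320)))
print(x.shape)   # torch.Size([3, 224, 224]): resize to 256, center-crop to 224
```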

9
lib/models/__init__.py Normal file
View File

@ -0,0 +1,9 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .cls_cvt import *
from .registry import *
from .build import build_model

10
lib/models/build.py Normal file
View File

@ -0,0 +1,10 @@
from .registry import model_entrypoints
from .registry import is_model
def build_model(config, **kwargs):
model_name = config.MODEL.NAME
if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')
return model_entrypoints(model_name)(config, **kwargs)

645
lib/models/cls_cvt.py Normal file
View File

@ -0,0 +1,645 @@
from functools import partial
from itertools import repeat
from torch._six import container_abcs
import logging
import os
from collections import OrderedDict
import numpy as np
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath, trunc_normal_
from .registry import register_model
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self,
dim_in,
dim_out,
num_heads,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.,
method='dw_bn',
kernel_size=3,
stride_kv=1,
stride_q=1,
padding_kv=1,
padding_q=1,
with_cls_token=True,
**kwargs
):
super().__init__()
self.stride_kv = stride_kv
self.stride_q = stride_q
self.dim = dim_out
self.num_heads = num_heads
# head_dim = self.qkv_dim // num_heads
self.scale = dim_out ** -0.5
self.with_cls_token = with_cls_token
self.conv_proj_q = self._build_projection(
dim_in, dim_out, kernel_size, padding_q,
stride_q, 'linear' if method == 'avg' else method
)
self.conv_proj_k = self._build_projection(
dim_in, dim_out, kernel_size, padding_kv,
stride_kv, method
)
self.conv_proj_v = self._build_projection(
dim_in, dim_out, kernel_size, padding_kv,
stride_kv, method
)
self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim_out, dim_out)
self.proj_drop = nn.Dropout(proj_drop)
def _build_projection(self,
dim_in,
dim_out,
kernel_size,
padding,
stride,
method):
if method == 'dw_bn':
proj = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(
dim_in,
dim_in,
kernel_size=kernel_size,
padding=padding,
stride=stride,
bias=False,
groups=dim_in
)),
('bn', nn.BatchNorm2d(dim_in)),
('rearrage', Rearrange('b c h w -> b (h w) c')),
]))
elif method == 'avg':
proj = nn.Sequential(OrderedDict([
('avg', nn.AvgPool2d(
kernel_size=kernel_size,
padding=padding,
stride=stride,
ceil_mode=True
)),
('rearrage', Rearrange('b c h w -> b (h w) c')),
]))
elif method == 'linear':
proj = None
else:
raise ValueError('Unknown method ({})'.format(method))
return proj
def forward_conv(self, x, h, w):
if self.with_cls_token:
cls_token, x = torch.split(x, [1, h*w], 1)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
if self.conv_proj_q is not None:
q = self.conv_proj_q(x)
else:
q = rearrange(x, 'b c h w -> b (h w) c')
if self.conv_proj_k is not None:
k = self.conv_proj_k(x)
else:
k = rearrange(x, 'b c h w -> b (h w) c')
if self.conv_proj_v is not None:
v = self.conv_proj_v(x)
else:
v = rearrange(x, 'b c h w -> b (h w) c')
if self.with_cls_token:
q = torch.cat((cls_token, q), dim=1)
k = torch.cat((cls_token, k), dim=1)
v = torch.cat((cls_token, v), dim=1)
return q, k, v
def forward(self, x, h, w):
if (
self.conv_proj_q is not None
or self.conv_proj_k is not None
or self.conv_proj_v is not None
):
q, k, v = self.forward_conv(x, h, w)
q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)
attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
attn = F.softmax(attn_score, dim=-1)
attn = self.attn_drop(attn)
x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
x = rearrange(x, 'b h t d -> b t (h d)')
x = self.proj(x)
x = self.proj_drop(x)
return x
@staticmethod
def compute_macs(module, input, output):
# T: num_token
# S: num_token
input = input[0]
flops = 0
_, T, C = input.shape
H = W = int(np.sqrt(T-1)) if module.with_cls_token else int(np.sqrt(T))
H_Q = H / module.stride_q
        W_Q = W / module.stride_q
T_Q = H_Q * W_Q + 1 if module.with_cls_token else H_Q * W_Q
H_KV = H / module.stride_kv
W_KV = W / module.stride_kv
T_KV = H_KV * W_KV + 1 if module.with_cls_token else H_KV * W_KV
# C = module.dim
# S = T
# Scaled-dot-product macs
# [B x T x C] x [B x C x T] --> [B x T x S]
# multiplication-addition is counted as 1 because operations can be fused
flops += T_Q * T_KV * module.dim
# [B x T x S] x [B x S x C] --> [B x T x C]
flops += T_Q * module.dim * T_KV
if (
hasattr(module, 'conv_proj_q')
and hasattr(module.conv_proj_q, 'conv')
):
params = sum(
[
p.numel()
for p in module.conv_proj_q.conv.parameters()
]
)
flops += params * H_Q * W_Q
if (
hasattr(module, 'conv_proj_k')
and hasattr(module.conv_proj_k, 'conv')
):
params = sum(
[
p.numel()
for p in module.conv_proj_k.conv.parameters()
]
)
flops += params * H_KV * W_KV
if (
hasattr(module, 'conv_proj_v')
and hasattr(module.conv_proj_v, 'conv')
):
params = sum(
[
p.numel()
for p in module.conv_proj_v.conv.parameters()
]
)
flops += params * H_KV * W_KV
params = sum([p.numel() for p in module.proj_q.parameters()])
flops += params * T_Q
params = sum([p.numel() for p in module.proj_k.parameters()])
flops += params * T_KV
params = sum([p.numel() for p in module.proj_v.parameters()])
flops += params * T_KV
params = sum([p.numel() for p in module.proj.parameters()])
flops += params * T
module.__flops__ += flops
class Block(nn.Module):
def __init__(self,
dim_in,
dim_out,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
**kwargs):
super().__init__()
self.with_cls_token = kwargs['with_cls_token']
self.norm1 = norm_layer(dim_in)
self.attn = Attention(
dim_in, dim_out, num_heads, qkv_bias, attn_drop, drop,
**kwargs
)
self.drop_path = DropPath(drop_path) \
if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim_out)
dim_mlp_hidden = int(dim_out * mlp_ratio)
self.mlp = Mlp(
in_features=dim_out,
hidden_features=dim_mlp_hidden,
act_layer=act_layer,
drop=drop
)
def forward(self, x, h, w):
res = x
x = self.norm1(x)
attn = self.attn(x, h, w)
x = res + self.drop_path(attn)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class ConvEmbed(nn.Module):
""" Image to Conv Embedding
"""
def __init__(self,
patch_size=7,
in_chans=3,
embed_dim=64,
stride=4,
padding=2,
norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_chans, embed_dim,
kernel_size=patch_size,
stride=stride,
padding=padding
)
self.norm = norm_layer(embed_dim) if norm_layer else None
def forward(self, x):
x = self.proj(x)
B, C, H, W = x.shape
x = rearrange(x, 'b c h w -> b (h w) c')
if self.norm:
x = self.norm(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
return x
class VisionTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
patch_size=16,
patch_stride=16,
patch_padding=0,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
init='trunc_norm',
**kwargs):
super().__init__()
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.rearrage = None
self.patch_embed = ConvEmbed(
# img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
stride=patch_stride,
padding=patch_padding,
embed_dim=embed_dim,
norm_layer=norm_layer
)
with_cls_token = kwargs['with_cls_token']
if with_cls_token:
self.cls_token = nn.Parameter(
torch.zeros(1, 1, embed_dim)
)
else:
self.cls_token = None
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
blocks = []
for j in range(depth):
blocks.append(
Block(
dim_in=embed_dim,
dim_out=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[j],
act_layer=act_layer,
norm_layer=norm_layer,
**kwargs
)
)
self.blocks = nn.ModuleList(blocks)
if self.cls_token is not None:
trunc_normal_(self.cls_token, std=.02)
if init == 'xavier':
self.apply(self._init_weights_xavier)
else:
self.apply(self._init_weights_trunc_normal)
def _init_weights_trunc_normal(self, m):
if isinstance(m, nn.Linear):
logging.info('=> init weight of Linear from trunc norm')
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
logging.info('=> init bias of Linear to zeros')
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _init_weights_xavier(self, m):
if isinstance(m, nn.Linear):
logging.info('=> init weight of Linear from xavier uniform')
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
logging.info('=> init bias of Linear to zeros')
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
x = self.patch_embed(x)
B, C, H, W = x.size()
x = rearrange(x, 'b c h w -> b (h w) c')
cls_tokens = None
if self.cls_token is not None:
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = self.pos_drop(x)
for i, blk in enumerate(self.blocks):
x = blk(x, H, W)
if self.cls_token is not None:
cls_tokens, x = torch.split(x, [1, H*W], 1)
x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
return x, cls_tokens
class ConvolutionalVisionTransformer(nn.Module):
def __init__(self,
in_chans=3,
num_classes=1000,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
init='trunc_norm',
spec=None):
super().__init__()
self.num_classes = num_classes
self.num_stages = spec['NUM_STAGES']
for i in range(self.num_stages):
kwargs = {
'patch_size': spec['PATCH_SIZE'][i],
'patch_stride': spec['PATCH_STRIDE'][i],
'patch_padding': spec['PATCH_PADDING'][i],
'embed_dim': spec['DIM_EMBED'][i],
'depth': spec['DEPTH'][i],
'num_heads': spec['NUM_HEADS'][i],
'mlp_ratio': spec['MLP_RATIO'][i],
'qkv_bias': spec['QKV_BIAS'][i],
'drop_rate': spec['DROP_RATE'][i],
'attn_drop_rate': spec['ATTN_DROP_RATE'][i],
'drop_path_rate': spec['DROP_PATH_RATE'][i],
'with_cls_token': spec['CLS_TOKEN'][i],
'method': spec['QKV_PROJ_METHOD'][i],
'kernel_size': spec['KERNEL_QKV'][i],
'padding_q': spec['PADDING_Q'][i],
'padding_kv': spec['PADDING_KV'][i],
'stride_kv': spec['STRIDE_KV'][i],
'stride_q': spec['STRIDE_Q'][i],
}
stage = VisionTransformer(
in_chans=in_chans,
init=init,
act_layer=act_layer,
norm_layer=norm_layer,
**kwargs
)
setattr(self, f'stage{i}', stage)
in_chans = spec['DIM_EMBED'][i]
dim_embed = spec['DIM_EMBED'][-1]
self.norm = norm_layer(dim_embed)
self.cls_token = spec['CLS_TOKEN'][-1]
# Classifier head
self.head = nn.Linear(dim_embed, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.head.weight, std=0.02)
def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
if os.path.isfile(pretrained):
pretrained_dict = torch.load(pretrained, map_location='cpu')
logging.info(f'=> loading pretrained model {pretrained}')
model_dict = self.state_dict()
pretrained_dict = {
k: v for k, v in pretrained_dict.items()
if k in model_dict.keys()
}
need_init_state_dict = {}
for k, v in pretrained_dict.items():
need_init = (
k.split('.')[0] in pretrained_layers
                    or pretrained_layers[0] == '*'
)
if need_init:
if verbose:
logging.info(f'=> init {k} from {pretrained}')
if 'pos_embed' in k and v.size() != model_dict[k].size():
size_pretrained = v.size()
size_new = model_dict[k].size()
logging.info(
'=> load_pretrained: resized variant: {} to {}'
.format(size_pretrained, size_new)
)
ntok_new = size_new[1]
ntok_new -= 1
posemb_tok, posemb_grid = v[:, :1], v[0, 1:]
gs_old = int(np.sqrt(len(posemb_grid)))
gs_new = int(np.sqrt(ntok_new))
logging.info(
'=> load_pretrained: grid-size from {} to {}'
.format(gs_old, gs_new)
)
posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
zoom = (gs_new / gs_old, gs_new / gs_old, 1)
posemb_grid = scipy.ndimage.zoom(
posemb_grid, zoom, order=1
)
posemb_grid = posemb_grid.reshape(1, gs_new ** 2, -1)
v = torch.tensor(
np.concatenate([posemb_tok, posemb_grid], axis=1)
)
need_init_state_dict[k] = v
self.load_state_dict(need_init_state_dict, strict=False)
@torch.jit.ignore
def no_weight_decay(self):
layers = set()
for i in range(self.num_stages):
layers.add(f'stage{i}.pos_embed')
layers.add(f'stage{i}.cls_token')
return layers
def forward_features(self, x):
for i in range(self.num_stages):
x, cls_tokens = getattr(self, f'stage{i}')(x)
if self.cls_token:
x = self.norm(cls_tokens)
            x = torch.squeeze(x, dim=1)
else:
x = rearrange(x, 'b c h w -> b (h w) c')
x = self.norm(x)
x = torch.mean(x, dim=1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
@register_model
def get_cls_model(config, **kwargs):
msvit_spec = config.MODEL.SPEC
msvit = ConvolutionalVisionTransformer(
in_chans=3,
num_classes=config.MODEL.NUM_CLASSES,
act_layer=QuickGELU,
norm_layer=partial(LayerNorm, eps=1e-5),
init=getattr(msvit_spec, 'INIT', 'trunc_norm'),
spec=msvit_spec
)
if config.MODEL.INIT_WEIGHTS:
msvit.init_weights(
config.MODEL.PRETRAINED,
config.MODEL.PRETRAINED_LAYERS,
config.VERBOSE
)
return msvit
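A hypothetical smoke test (not part of this commit) that builds CvT-13 straight from a spec dict, with keys mirroring the `MODEL.SPEC` block of the yamls, and runs a dummy forward pass:
``` python
from functools import partial
import torch
from lib.models.cls_cvt import ConvolutionalVisionTransformer, QuickGELU, LayerNorm

spec = dict(
    NUM_STAGES=3,
    PATCH_SIZE=[7, 3, 3], PATCH_STRIDE=[4, 2, 2], PATCH_PADDING=[2, 1, 1],
    DIM_EMBED=[64, 192, 384], NUM_HEADS=[1, 3, 6], DEPTH=[1, 2, 10],
    MLP_RATIO=[4.0] * 3, ATTN_DROP_RATE=[0.0] * 3, DROP_RATE=[0.0] * 3,
    DROP_PATH_RATE=[0.0, 0.0, 0.1], QKV_BIAS=[True] * 3,
    CLS_TOKEN=[False, False, True],
    QKV_PROJ_METHOD=['dw_bn'] * 3, KERNEL_QKV=[3] * 3,
    PADDING_KV=[1] * 3, STRIDE_KV=[2] * 3, PADDING_Q=[1] * 3, STRIDE_Q=[1] * 3,
)
model = ConvolutionalVisionTransformer(
    num_classes=1000, act_layer=QuickGELU,
    norm_layer=partial(LayerNorm, eps=1e-5), spec=spec,
)
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)   # torch.Size([2, 1000])
```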

18
lib/models/registry.py Normal file
View File

@ -0,0 +1,18 @@
_model_entrypoints = {}
def register_model(fn):
module_name_split = fn.__module__.split('.')
model_name = module_name_split[-1]
_model_entrypoints[model_name] = fn
return fn
def model_entrypoints(model_name):
return _model_entrypoints[model_name]
def is_model(model_name):
return model_name in _model_entrypoints
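The decorator keys each entrypoint by the defining module's name, so `get_cls_model` in `lib/models/cls_cvt.py` is registered as `cls_cvt`, which is exactly what `MODEL.NAME: cls_cvt` in the yamls resolves to:
``` python
import lib.models  # importing the package runs the @register_model decorators
from lib.models.registry import is_model, model_entrypoints

print(is_model('cls_cvt'))           # True
print(model_entrypoints('cls_cvt'))  # <function get_cls_model at 0x...>
```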

1
lib/optim/__init__.py Normal file
View File

@ -0,0 +1 @@
from .build import build_optimizer

155
lib/optim/build.py Normal file
View File

@ -0,0 +1,155 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch.nn as nn
import torch.optim as optim
from timm.optim import create_optimizer
def _is_depthwise(m):
return (
isinstance(m, nn.Conv2d)
and m.groups == m.in_channels
and m.groups == m.out_channels
)
def set_wd(cfg, model):
without_decay_list = cfg.TRAIN.WITHOUT_WD_LIST
without_decay_depthwise = []
without_decay_norm = []
for m in model.modules():
if _is_depthwise(m) and 'dw' in without_decay_list:
without_decay_depthwise.append(m.weight)
elif isinstance(m, nn.BatchNorm2d) and 'bn' in without_decay_list:
without_decay_norm.append(m.weight)
without_decay_norm.append(m.bias)
elif isinstance(m, nn.GroupNorm) and 'gn' in without_decay_list:
without_decay_norm.append(m.weight)
without_decay_norm.append(m.bias)
elif isinstance(m, nn.LayerNorm) and 'ln' in without_decay_list:
without_decay_norm.append(m.weight)
without_decay_norm.append(m.bias)
with_decay = []
without_decay = []
skip = {}
if hasattr(model, 'no_weight_decay'):
skip = model.no_weight_decay()
skip_keys = {}
if hasattr(model, 'no_weight_decay_keywords'):
skip_keys = model.no_weight_decay_keywords()
for n, p in model.named_parameters():
ever_set = False
if p.requires_grad is False:
continue
skip_flag = False
if n in skip:
print('=> set {} wd to 0'.format(n))
without_decay.append(p)
skip_flag = True
else:
for i in skip:
if i in n:
print('=> set {} wd to 0'.format(n))
without_decay.append(p)
skip_flag = True
if skip_flag:
continue
        for i in skip_keys:
            if i in n:
                print('=> set {} wd to 0'.format(n))
                without_decay.append(p)
                skip_flag = True
        if skip_flag:
            continue
for pp in without_decay_depthwise:
if p is pp:
if cfg.VERBOSE:
print('=> set depthwise({}) wd to 0'.format(n))
without_decay.append(p)
ever_set = True
break
for pp in without_decay_norm:
if p is pp:
if cfg.VERBOSE:
print('=> set norm({}) wd to 0'.format(n))
without_decay.append(p)
ever_set = True
break
if (
(not ever_set)
and 'bias' in without_decay_list
and n.endswith('.bias')
):
if cfg.VERBOSE:
print('=> set bias({}) wd to 0'.format(n))
without_decay.append(p)
elif not ever_set:
with_decay.append(p)
# assert (len(with_decay) + len(without_decay) == len(list(model.parameters())))
params = [
{'params': with_decay},
{'params': without_decay, 'weight_decay': 0.}
]
return params
def build_optimizer(cfg, model):
if cfg.TRAIN.OPTIMIZER == 'timm':
args = cfg.TRAIN.OPTIMIZER_ARGS
        print(f'=> using timm optimizer args: {cfg.TRAIN.OPTIMIZER_ARGS}')
optimizer = create_optimizer(args, model)
return optimizer
optimizer = None
params = set_wd(cfg, model)
if cfg.TRAIN.OPTIMIZER == 'sgd':
optimizer = optim.SGD(
params,
# filter(lambda p: p.requires_grad, model.parameters()),
lr=cfg.TRAIN.LR,
momentum=cfg.TRAIN.MOMENTUM,
weight_decay=cfg.TRAIN.WD,
nesterov=cfg.TRAIN.NESTEROV
)
elif cfg.TRAIN.OPTIMIZER == 'adam':
optimizer = optim.Adam(
params,
# filter(lambda p: p.requires_grad, model.parameters()),
lr=cfg.TRAIN.LR,
weight_decay=cfg.TRAIN.WD,
)
elif cfg.TRAIN.OPTIMIZER == 'adamW':
optimizer = optim.AdamW(
params,
lr=cfg.TRAIN.LR,
weight_decay=cfg.TRAIN.WD,
)
elif cfg.TRAIN.OPTIMIZER == 'rmsprop':
optimizer = optim.RMSprop(
params,
# filter(lambda p: p.requires_grad, model.parameters()),
lr=cfg.TRAIN.LR,
momentum=cfg.TRAIN.MOMENTUM,
weight_decay=cfg.TRAIN.WD,
alpha=cfg.TRAIN.RMSPROP_ALPHA,
centered=cfg.TRAIN.RMSPROP_CENTERED
)
return optimizer
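A small illustration of `set_wd`'s grouping (illustrative only): with `WITHOUT_WD_LIST: ['bn', 'bias', 'ln']` as in the configs, LayerNorm parameters and all biases land in the zero-decay group:
``` python
import torch.nn as nn
from yacs.config import CfgNode as CN
from lib.optim.build import set_wd

cfg = CN()
cfg.VERBOSE = False
cfg.TRAIN = CN()
cfg.TRAIN.WITHOUT_WD_LIST = ['bn', 'bias', 'ln']

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
groups = set_wd(cfg, model)
print(len(groups[0]['params']), len(groups[1]['params']))  # 1 decayed, 3 at wd=0
```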

View File

@ -0,0 +1 @@
from .build import build_lr_scheduler

41
lib/scheduler/build.py Normal file
View File

@ -0,0 +1,41 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
from timm.scheduler import create_scheduler
def build_lr_scheduler(cfg, optimizer, begin_epoch):
if 'METHOD' not in cfg.TRAIN.LR_SCHEDULER:
raise ValueError('Please set TRAIN.LR_SCHEDULER.METHOD!')
elif cfg.TRAIN.LR_SCHEDULER.METHOD == 'MultiStep':
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
cfg.TRAIN.LR_SCHEDULER.MILESTONES,
cfg.TRAIN.LR_SCHEDULER.GAMMA,
begin_epoch - 1)
elif cfg.TRAIN.LR_SCHEDULER.METHOD == 'CosineAnnealing':
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
cfg.TRAIN.END_EPOCH,
cfg.TRAIN.LR_SCHEDULER.ETA_MIN,
begin_epoch - 1
)
elif cfg.TRAIN.LR_SCHEDULER.METHOD == 'CyclicLR':
lr_scheduler = torch.optim.lr_scheduler.CyclicLR(
optimizer,
base_lr=cfg.TRAIN.LR_SCHEDULER.BASE_LR,
            max_lr=cfg.TRAIN.LR_SCHEDULER.MAX_LR,
step_size_up=cfg.TRAIN.LR_SCHEDULER.STEP_SIZE_UP
)
elif cfg.TRAIN.LR_SCHEDULER.METHOD == 'timm':
args = cfg.TRAIN.LR_SCHEDULER.ARGS
lr_scheduler, _ = create_scheduler(args, optimizer)
lr_scheduler.step(begin_epoch)
else:
raise ValueError('Unknown lr scheduler: {}'.format(
cfg.TRAIN.LR_SCHEDULER.METHOD))
return lr_scheduler

133
lib/utils/comm.py Normal file
View File

@ -0,0 +1,133 @@
import pickle
import torch
import torch.distributed as dist
class Comm(object):
def __init__(self, local_rank=0):
        self.local_rank = local_rank
@property
def world_size(self):
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
@property
def rank(self):
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
@property
def local_rank(self):
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return self._local_rank
@local_rank.setter
def local_rank(self, value):
if not dist.is_available():
self._local_rank = 0
return
if not dist.is_initialized():
self._local_rank = 0
return
self._local_rank = value
@property
def head(self):
return 'Rank[{}/{}]'.format(self.rank, self.world_size)
def is_main_process(self):
return self.rank == 0
def synchronize(self):
"""
Helper function to synchronize (barrier) among all processes when
using distributed training
"""
if self.world_size == 1:
return
dist.barrier()
comm = Comm()
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = comm.world_size
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.LongTensor([tensor.numel()]).to("cuda")
size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
if local_size != max_size:
padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that process with rank
0 has the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = comm.world_size
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.reduce(values, dst=0)
if dist.get_rank() == 0 and average:
# only main process gets accumulated, so only divide by
# world_size in this case
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
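Both helpers fall back gracefully when torch.distributed is not initialized (`world_size == 1`), so they can be sanity-checked in a single process. A quick check, assuming `lib` is on `sys.path` as the tools scripts arrange via `_init_paths`:

``` python
import torch
from utils.comm import comm, all_gather, reduce_dict

# with no process group initialized, these are effectively no-ops
print(comm.head)                                 # Rank[0/1]
print(all_gather({'top1': 81.6}))                # [{'top1': 81.6}]
print(reduce_dict({'loss': torch.tensor(0.5)}))  # {'loss': tensor(0.5000)}
```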

213 lib/utils/utils.py Normal file
@@ -0,0 +1,213 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import timedelta
from pathlib import Path
import os
import logging
import shutil
import time
import tensorwatch as tw
import torch
import torch.backends.cudnn as cudnn
from utils.comm import comm
from ptflops import get_model_complexity_info
def setup_logger(final_output_dir, rank, phase):
time_str = time.strftime('%Y-%m-%d-%H-%M')
log_file = '{}_{}_rank{}.txt'.format(phase, time_str, rank)
final_log_file = os.path.join(final_output_dir, log_file)
head = '%(asctime)-15s:[P:%(process)d]:' + comm.head + ' %(message)s'
logging.basicConfig(
filename=str(final_log_file), format=head
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
console = logging.StreamHandler()
console.setFormatter(
logging.Formatter(head)
)
logging.getLogger('').addHandler(console)
def create_logger(cfg, cfg_name, phase='train'):
root_output_dir = Path(cfg.OUTPUT_DIR)
dataset = cfg.DATASET.DATASET
cfg_name = cfg.NAME
final_output_dir = root_output_dir / dataset / cfg_name
print('=> creating {} ...'.format(root_output_dir))
root_output_dir.mkdir(parents=True, exist_ok=True)
print('=> creating {} ...'.format(final_output_dir))
final_output_dir.mkdir(parents=True, exist_ok=True)
print('=> setup logger ...')
setup_logger(final_output_dir, cfg.RANK, phase)
return str(final_output_dir)
def init_distributed(args):
args.num_gpus = int(os.environ["WORLD_SIZE"]) \
if "WORLD_SIZE" in os.environ else 1
args.distributed = args.num_gpus > 1
if args.distributed:
print("=> init process group start")
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(
backend="nccl", init_method="env://",
timeout=timedelta(minutes=180))
comm.local_rank = args.local_rank
print("=> init process group end")
def setup_cudnn(config):
cudnn.benchmark = config.CUDNN.BENCHMARK
torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
torch.backends.cudnn.enabled = config.CUDNN.ENABLED
def count_parameters(model):
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
return params/1000000
def summary_model_on_master(model, config, output_dir, copy):
if comm.is_main_process():
this_dir = os.path.dirname(__file__)
shutil.copy2(
os.path.join(this_dir, '../models', config.MODEL.NAME + '.py'),
output_dir
)
logging.info('=> {}'.format(model))
try:
num_params = count_parameters(model)
logging.info("Trainable model total parameters: \t%2.1fM" % num_params)
except Exception:
logging.error('=> error when counting parameters')
if config.MODEL_SUMMARY:
try:
logging.info('== model_stats by tensorwatch ==')
df = tw.model_stats(
model,
(1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
)
df.to_html(os.path.join(output_dir, 'model_summary.html'))
df.to_csv(os.path.join(output_dir, 'model_summary.csv'))
msg = '*'*20 + ' Model summary ' + '*'*20
logging.info(
'\n{msg}\n{summary}\n{msg}'.format(
msg=msg, summary=df.iloc[-1]
)
)
logging.info('== model_stats by tensorwatch ==')
except Exception:
logging.error('=> error when run model_stats')
try:
logging.info('== get_model_complexity_info by ptflops ==')
macs, params = get_model_complexity_info(
model,
(3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]),
as_strings=True, print_per_layer_stat=True, verbose=True
)
logging.info(f'=> MACs: {macs:<8}, params: {params:<8}')  # ptflops reports MACs, not FLOPs
logging.info('== get_model_complexity_info by ptflops ==')
except Exception:
logging.error('=> error when run get_model_complexity_info')
def resume_checkpoint(model,
optimizer,
config,
output_dir,
in_epoch):
best_perf = 0.0
begin_epoch_or_step = 0
checkpoint = os.path.join(output_dir, 'checkpoint.pth')\
if not config.TRAIN.CHECKPOINT else config.TRAIN.CHECKPOINT
if config.TRAIN.AUTO_RESUME and os.path.exists(checkpoint):
logging.info(
"=> loading checkpoint '{}'".format(checkpoint)
)
checkpoint_dict = torch.load(checkpoint, map_location='cpu')
best_perf = checkpoint_dict['perf']
begin_epoch_or_step = checkpoint_dict['epoch' if in_epoch else 'step']
state_dict = checkpoint_dict['state_dict']
model.load_state_dict(state_dict)
optimizer.load_state_dict(checkpoint_dict['optimizer'])
logging.info(
"=> {}: loaded checkpoint '{}' ({}: {})"
.format(comm.head,
checkpoint,
'epoch' if in_epoch else 'step',
begin_epoch_or_step)
)
return best_perf, begin_epoch_or_step
def save_checkpoint_on_master(model,
*,
distributed,
model_name,
optimizer,
output_dir,
in_epoch,
epoch_or_step,
best_perf):
if not comm.is_main_process():
return
states = model.module.state_dict() \
if distributed else model.state_dict()
logging.info('=> saving checkpoint to {}'.format(output_dir))
save_dict = {
'epoch' if in_epoch else 'step': epoch_or_step + 1,
'model': model_name,
'state_dict': states,
'perf': best_perf,
'optimizer': optimizer.state_dict(),
}
try:
torch.save(save_dict, os.path.join(output_dir, 'checkpoint.pth'))
except Exception:
logging.error('=> error when saving checkpoint!')
def save_model_on_master(model, distributed, out_dir, fname):
if not comm.is_main_process():
return
try:
fname_full = os.path.join(out_dir, fname)
logging.info(f'=> save model to {fname_full}')
torch.save(
model.module.state_dict() if distributed else model.state_dict(),
fname_full
)
except Exception:
logging.error('=> error when saving model!')
def strip_prefix_if_present(state_dict, prefix):
keys = sorted(state_dict.keys())
if not all(key.startswith(prefix) for key in keys):
return state_dict
from collections import OrderedDict
stripped_state_dict = OrderedDict()
for key, value in state_dict.items():
stripped_state_dict[key.replace(prefix, "")] = value
return stripped_state_dict
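`strip_prefix_if_present` is useful when a checkpoint saved from a DistributedDataParallel wrapper (keys prefixed with `module.`) is loaded into a bare model. A toy example; note that `str.replace` removes every occurrence of the prefix, which is harmless for `module.` but worth knowing:

``` python
from collections import OrderedDict
from utils.utils import strip_prefix_if_present  # assumes lib/ on sys.path

state_dict = OrderedDict([
    ('module.head.weight', 'w'),  # toy values standing in for tensors
    ('module.head.bias', 'b'),
])
clean = strip_prefix_if_present(state_dict, 'module.')
print(list(clean.keys()))  # ['head.weight', 'head.bias']
```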

12 requirements.txt Normal file
@@ -0,0 +1,12 @@
opencv-python
pyyaml
json_tricks
yacs
scikit-learn
tensorwatch
tensorboardX
pandas
ptflops
timm==0.3.2
numpy==1.19.3
einops

87 run.sh Normal file
@@ -0,0 +1,87 @@
#!/bin/bash
train() {
python3 -m torch.distributed.launch \
--nnodes ${NODE_COUNT} \
--node_rank ${RANK} \
--master_addr ${MASTER_ADDR} \
--master_port ${MASTER_PORT} \
--nproc_per_node ${GPUS} \
tools/train.py ${EXTRA_ARGS}
}
test() {
python3 -m torch.distributed.launch \
--nnodes ${NODE_COUNT} \
--node_rank ${RANK} \
--master_addr ${MASTER_ADDR} \
--master_port ${MASTER_PORT} \
--nproc_per_node ${GPUS} \
tools/test.py ${EXTRA_ARGS}
}
############################ Main #############################
GPUS=`nvidia-smi -L | wc -l`
MASTER_PORT=9000
INSTALL_DEPS=false
while [[ $# -gt 0 ]]
do
key="$1"
case $key in
-h|--help)
echo "Usage: $0 [run_options]"
echo "Options:"
echo " -g|--gpus <1> - number of gpus to be used"
echo " -t|--job-type <train> - job type (train|test)"
echo " -p|--port <9000> - master port"
echo " -i|--install-deps - If install dependencies (default: False)"
exit 1
;;
-g|--gpus)
GPUS=$2
shift
;;
-t|--job-type)
JOB_TYPE=$2
shift
;;
-p|--port)
MASTER_PORT=$2
shift
;;
-i|--install-deps)
INSTALL_DEPS=true
;;
*)
EXTRA_ARGS="$EXTRA_ARGS $1"
;;
esac
shift
done
if $INSTALL_DEPS; then
python -m pip install -r requirements.txt --user -q
fi
RANK=0
MASTER_ADDR=127.0.0.1
NODE_COUNT=1
echo "job type: ${JOB_TYPE}"
echo "rank: ${RANK}"
echo "node count: ${NODE_COUNT}"
echo "master addr: ${MASTER_ADDR}"
case $JOB_TYPE in
train)
train
;;
test)
test
;;
*)
echo "unknown job type"
;;
esac

18 tools/_init_paths.py Normal file
@@ -0,0 +1,18 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path as osp
import sys
def add_path(path):
if path not in sys.path:
sys.path.insert(0, path)
this_dir = osp.dirname(__file__)
lib_path = osp.join(this_dir, '..', 'lib')
add_path(lib_path)
add_path(osp.join(this_dir, '..'))
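This side-effect import is what makes the bare `from config import config` lines in tools/train.py and tools/test.py resolve to modules under `lib/`; a new tools script only needs:

``` python
import _init_paths  # noqa: F401  (side effect: prepends ../lib to sys.path)
from config import config  # now resolves to lib/config
```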

144 tools/test.py Normal file
@@ -0,0 +1,144 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import logging
import os
import pickle as pkl
import pprint
import time
import torch
import torch.nn.parallel
import torch.optim
from torch.utils.collect_env import get_pretty_env_info
from tensorboardX import SummaryWriter
import _init_paths
from config import config
from config import update_config
from core.function import test
from core.loss import build_criterion
from dataset import build_dataloader
from dataset import RealLabelsImagenet
from models import build_model
from utils.comm import comm
from utils.utils import create_logger
from utils.utils import init_distributed
from utils.utils import setup_cudnn
from utils.utils import summary_model_on_master
from utils.utils import strip_prefix_if_present
def parse_args():
parser = argparse.ArgumentParser(
description='Test classification network')
parser.add_argument('--cfg',
help='experiment configure file name',
required=True,
type=str)
# distributed training
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--port", type=int, default=9000)
parser.add_argument('opts',
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER)
args = parser.parse_args()
return args
def main():
args = parse_args()
init_distributed(args)
setup_cudnn(config)
update_config(config, args)
final_output_dir = create_logger(config, args.cfg, 'test')
tb_log_dir = final_output_dir
if comm.is_main_process():
logging.info("=> collecting env info (might take some time)")
logging.info("\n" + get_pretty_env_info())
logging.info(pprint.pformat(args))
logging.info(config)
logging.info("=> using {} GPUs".format(args.num_gpus))
output_config_path = os.path.join(final_output_dir, 'config.yaml')
logging.info("=> saving config into: {}".format(output_config_path))
model = build_model(config)
model.to(torch.device('cuda'))
model_file = config.TEST.MODEL_FILE if config.TEST.MODEL_FILE \
else os.path.join(final_output_dir, 'model_best.pth')
logging.info('=> load model file: {}'.format(model_file))
ext = model_file.split('.')[-1]
if ext == 'pth':
state_dict = torch.load(model_file, map_location="cpu")
else:
raise ValueError("Unknown model file")
model.load_state_dict(state_dict, strict=False)
model.to(torch.device('cuda'))
writer_dict = {
'writer': SummaryWriter(logdir=tb_log_dir),
'train_global_steps': 0,
'valid_global_steps': 0,
}
summary_model_on_master(model, config, final_output_dir, False)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank
)
# define loss function (criterion) and optimizer
criterion = build_criterion(config, train=False)
criterion.cuda()
valid_loader = build_dataloader(config, False, args.distributed)
real_labels = None
if (
config.DATASET.DATASET == 'imagenet'
and config.DATASET.DATA_FORMAT == 'tsv'
and config.TEST.REAL_LABELS
):
filenames = valid_loader.dataset.get_filenames()
real_json = os.path.join(config.DATASET.ROOT, 'real.json')
logging.info('=> loading real labels...')
real_labels = RealLabelsImagenet(filenames, real_json)
valid_labels = None
if config.TEST.VALID_LABELS:
with open(config.TEST.VALID_LABELS, 'r') as f:
valid_labels = {
int(line.rstrip()) for line in f
}
valid_labels = [
i in valid_labels for i in range(config.MODEL.NUM_CLASSES)
]
logging.info('=> start testing')
start = time.time()
test(config, valid_loader, model, criterion,
final_output_dir, tb_log_dir, writer_dict,
args.distributed, real_labels=real_labels,
valid_labels=valid_labels)
logging.info('=> test duration time: {:.2f}s'.format(time.time()-start))
writer_dict['writer'].close()
logging.info('=> finish testing')
if __name__ == '__main__':
main()
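The `valid_labels` handling above turns a file of class ids into a boolean mask indexed by class, which the test routine can then use to restrict predictions. In isolation, with made-up numbers:

``` python
num_classes = 5
valid = {0, 3}  # ids as read from TEST.VALID_LABELS, one per line
mask = [i in valid for i in range(num_classes)]
print(mask)  # [True, False, False, True, False]
```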

212 tools/train.py Normal file
@@ -0,0 +1,212 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import logging
import os
import pprint
import time
import torch
import torch.nn.parallel
import torch.optim
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.utils.collect_env import get_pretty_env_info
from tensorboardX import SummaryWriter
import _init_paths
from config import config
from config import update_config
from config import save_config
from core.loss import build_criterion
from core.function import train_one_epoch, test
from dataset import build_dataloader
from models import build_model
from optim import build_optimizer
from scheduler import build_lr_scheduler
from utils.comm import comm
from utils.utils import create_logger
from utils.utils import init_distributed
from utils.utils import setup_cudnn
from utils.utils import summary_model_on_master
from utils.utils import resume_checkpoint
from utils.utils import save_checkpoint_on_master
from utils.utils import save_model_on_master
def parse_args():
parser = argparse.ArgumentParser(
description='Train classification network')
parser.add_argument('--cfg',
help='experiment configure file name',
required=True,
type=str)
# distributed training
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--port", type=int, default=9000)
parser.add_argument('opts',
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER)
args = parser.parse_args()
return args
def main():
args = parse_args()
init_distributed(args)
setup_cudnn(config)
update_config(config, args)
final_output_dir = create_logger(config, args.cfg, 'train')
tb_log_dir = final_output_dir
if comm.is_main_process():
logging.info("=> collecting env info (might take some time)")
logging.info("\n" + get_pretty_env_info())
logging.info(pprint.pformat(args))
logging.info(config)
logging.info("=> using {} GPUs".format(args.num_gpus))
output_config_path = os.path.join(final_output_dir, 'config.yaml')
logging.info("=> saving config into: {}".format(output_config_path))
save_config(config, output_config_path)
model = build_model(config)
model.to(torch.device('cuda'))
# copy model file
summary_model_on_master(model, config, final_output_dir, True)
if config.AMP.ENABLED and config.AMP.MEMORY_FORMAT == 'nhwc':
logging.info('=> convert memory format to nhwc')
model.to(memory_format=torch.channels_last)
writer_dict = {
'writer': SummaryWriter(logdir=tb_log_dir),
'train_global_steps': 0,
'valid_global_steps': 0,
}
best_perf = 0.0
best_model = True
begin_epoch = config.TRAIN.BEGIN_EPOCH
optimizer = build_optimizer(config, model)
best_perf, begin_epoch = resume_checkpoint(
model, optimizer, config, final_output_dir, True
)
train_loader = build_dataloader(config, True, args.distributed)
valid_loader = build_dataloader(config, False, args.distributed)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True
)
criterion = build_criterion(config)
criterion.cuda()
criterion_eval = build_criterion(config, train=False)
criterion_eval.cuda()
lr_scheduler = build_lr_scheduler(config, optimizer, begin_epoch)
scaler = torch.cuda.amp.GradScaler(enabled=config.AMP.ENABLED)
logging.info('=> start training')
for epoch in range(begin_epoch, config.TRAIN.END_EPOCH):
head = 'Epoch[{}]:'.format(epoch)
logging.info('=> {} epoch start'.format(head))
start = time.time()
if args.distributed:
train_loader.sampler.set_epoch(epoch)
# train for one epoch
logging.info('=> {} train start'.format(head))
with torch.autograd.set_detect_anomaly(config.TRAIN.DETECT_ANOMALY):
train_one_epoch(config, train_loader, model, criterion, optimizer,
epoch, final_output_dir, tb_log_dir, writer_dict,
scaler=scaler)
logging.info(
'=> {} train end, duration: {:.2f}s'
.format(head, time.time()-start)
)
# evaluate on validation set
logging.info('=> {} validate start'.format(head))
val_start = time.time()
if epoch >= config.TRAIN.EVAL_BEGIN_EPOCH:
perf = test(
config, valid_loader, model, criterion_eval,
final_output_dir, tb_log_dir, writer_dict,
args.distributed
)
best_model = (perf > best_perf)
best_perf = perf if best_model else best_perf
logging.info(
'=> {} validate end, duration: {:.2f}s'
.format(head, time.time()-val_start)
)
lr_scheduler.step(epoch=epoch+1)
if config.TRAIN.LR_SCHEDULER.METHOD == 'timm':
lr = lr_scheduler.get_epoch_values(epoch+1)[0]
else:
lr = lr_scheduler.get_last_lr()[0]
logging.info(f'=> lr: {lr}')
save_checkpoint_on_master(
model=model,
distributed=args.distributed,
model_name=config.MODEL.NAME,
optimizer=optimizer,
output_dir=final_output_dir,
in_epoch=True,
epoch_or_step=epoch,
best_perf=best_perf,
)
if best_model and comm.is_main_process():
save_model_on_master(
model, args.distributed, final_output_dir, 'model_best.pth'
)
if config.TRAIN.SAVE_ALL_MODELS and comm.is_main_process():
save_model_on_master(
model, args.distributed, final_output_dir, f'model_{epoch}.pth'
)
logging.info(
'=> {} epoch end, duration : {:.2f}s'
.format(head, time.time()-start)
)
save_model_on_master(
model, args.distributed, final_output_dir, 'final_state.pth'
)
if config.SWA.ENABLED and comm.is_main_process():
# no separate SWA-averaged model is built in this script, so the current
# weights are saved; save_model_on_master expects the model as first arg
save_model_on_master(
model, args.distributed, final_output_dir, 'swa_state.pth'
)
writer_dict['writer'].close()
logging.info('=> finish training')
if __name__ == '__main__':
main()
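`train_one_epoch` receives the GradScaler constructed above. Its body lives elsewhere in the repo, but the canonical AMP step it is expected to perform looks like this (a generic sketch requiring a CUDA device, not the repo's exact loop):

``` python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=True)

x = torch.randn(8, 4, device='cuda')
y = torch.randint(0, 2, (8,), device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=True):
    loss = F.cross_entropy(model(x), y)
scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
scaler.step(optimizer)         # unscales grads; skips the step on inf/nan
scaler.update()
```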