зеркало из https://github.com/microsoft/CvT.git
This commit is contained in:
Родитель
f42d58b109
Коммит
56984edeed
130
README.md
130
README.md
|
@ -1,15 +1,129 @@
|
|||
# Project
|
||||
# Introduction
|
||||
This is an official implementation of [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808). We present a new architecture, named Convolutional vision Transformers (CvT), that improves Vision Transformers (ViT) in performance and efficienty by introducing convolutions into ViT to yield the best of both disignes. This is accomplished through two primary modifications: a hierarchy of Transformers containing a new convolutional token embedding, and a convolutional Transformer block leveraging a convolutional projection. These changes introduce desirable properties of convolutional neural networks (CNNs) to the ViT architecture (e.g. shift, scale, and distortion invariance) while maintaining the merits of Transformers (e.g. dynamic attention, global context, and better generalization). We validate CvT by conducting extensive experiments, showing that this approach achieves state-of-the-art performance over other Vision Transformers and ResNets on ImageNet-1k, with fewer parameters and lower FLOPs. In addition, performance gains are maintained when pretrained on larger dataset (e.g. ImageNet-22k) and fine-tuned to downstream tasks. Pre-trained on ImageNet-22k, our CvT-W24 obtains a top-1 accuracy of 87.7% on the ImageNet-1k val set. Finally, our results show that the positional encoding, a crucial component in existing Vision Transformers, can be safely removed in our model, simplifying the design for higher resolution vision tasks.
|
||||
|
||||
> This repo has been populated by an initial template to help get you started. Please
|
||||
> make sure to update the content to build a great experience for community-building.
|
||||
![](figures/pipeline.svg)
|
||||
|
||||
As the maintainer of this project, please make a few updates:
|
||||
# Main results
|
||||
## Models pre-trained on ImageNet-1k
|
||||
| Model | Resolution | Param | GFLOPs | Top-1 |
|
||||
|--------|------------|-------|--------|-------|
|
||||
| CvT-13 | 224x224 | 20M | 4.5 | 81.6 |
|
||||
| CvT-21 | 224x224 | 32M | 7.1 | 82.5 |
|
||||
| CvT-13 | 384x384 | 20M | 16.3 | 83.0 |
|
||||
| CvT-32 | 384x384 | 32M | 24.9 | 83.3 |
|
||||
|
||||
- Improving this README.MD file to provide a great experience
|
||||
- Updating SUPPORT.MD with content about this project's support experience
|
||||
- Understanding the security reporting process in SECURITY.MD
|
||||
- Remove this section from the README
|
||||
## Models pre-trained on ImageNet-22k
|
||||
| Model | Resolution | Param | GFLOPs | Top-1 |
|
||||
|---------|------------|-------|--------|-------|
|
||||
| CvT-13 | 384x384 | 20M | 16.3 | 83.3 |
|
||||
| CvT-32 | 384x384 | 32M | 24.9 | 84.9 |
|
||||
| CvT-W24 | 384x384 | 277M | 193.2 | 87.6 |
|
||||
|
||||
You can download all the models from our [model zoo](https://1drv.ms/u/s!AhIXJn_J-blW9RzF3rMW7SsLHa8h?e=blQ0Al).
|
||||
|
||||
|
||||
# Quick start
|
||||
## Installation
|
||||
Assuming that you have installed PyTroch and TorchVision, if not, please follow the [officiall instruction](https://pytorch.org/) to install them firstly.
|
||||
Intall the dependencies using cmd:
|
||||
|
||||
``` sh
|
||||
python -m pip install -r requirements.txt --user -q
|
||||
```
|
||||
|
||||
The code is developed and tested using pytorch 1.7.1. Other versions of pytorch are not fully tested.
|
||||
|
||||
## Data preparation
|
||||
Please prepare the data as following:
|
||||
|
||||
``` sh
|
||||
|-DATASET
|
||||
|-imagenet
|
||||
|-train
|
||||
| |-class1
|
||||
| | |-img1.jpg
|
||||
| | |-img2.jpg
|
||||
| | |-...
|
||||
| |-class2
|
||||
| | |-img3.jpg
|
||||
| | |-...
|
||||
| |-class3
|
||||
| | |-img4.jpg
|
||||
| | |-...
|
||||
| |-...
|
||||
|-val
|
||||
|-class1
|
||||
| |-img5.jpg
|
||||
| |-...
|
||||
|-class2
|
||||
| |-img6.jpg
|
||||
| |-...
|
||||
|-class3
|
||||
| |-img7.jpg
|
||||
| |-...
|
||||
|-...
|
||||
```
|
||||
|
||||
|
||||
## Run
|
||||
Each experiment is defined by a yaml config file, which is saved under the directory of `experiments`. The directory of `experiments` has a tree structure like this:
|
||||
|
||||
``` sh
|
||||
experiments
|
||||
|-{DATASET_A}
|
||||
| |-{ARCH_A}
|
||||
| |-{ARCH_B}
|
||||
|-{DATASET_B}
|
||||
| |-{ARCH_A}
|
||||
| |-{ARCH_B}
|
||||
|-{DATASET_C}
|
||||
| |-{ARCH_A}
|
||||
| |-{ARCH_B}
|
||||
|-...
|
||||
```
|
||||
|
||||
We provide a `run.sh` script for running jobs in local machine.
|
||||
|
||||
``` sh
|
||||
Usage: run.sh [run_options]
|
||||
Options:
|
||||
-g|--gpus <1> - number of gpus to be used
|
||||
-t|--job-type <aml> - job type (train|test)
|
||||
-p|--port <9000> - master port
|
||||
-i|--install-deps - If install dependencies (default: False)
|
||||
```
|
||||
|
||||
### Training on local machine
|
||||
|
||||
``` sh
|
||||
bash run.sh -g 8 -t train --cfg experiments/imagenet/cvt/cvt-13-224x224.yaml
|
||||
```
|
||||
|
||||
You can also modify the config paramters by the command line. For example, if you want to change the lr rate to 0.1, you can run the command:
|
||||
``` sh
|
||||
bash run.sh -g 8 -t train --cfg experiments/imagenet/cvt/cvt-13-224x224.yaml TRAIN.LR 0.1
|
||||
```
|
||||
|
||||
Notes:
|
||||
- The checkpoint, model, and log files will be saved in OUTPUT/{dataset}/{training config} by default.
|
||||
|
||||
### Testing pre-trained models
|
||||
|
||||
``` sh
|
||||
bash run.sh -t test --cfg experiments/imagenet/cvt/cvt-13-224x224.yaml TEST.MODEL_FILE ${PRETRAINED_MODLE_FILE}
|
||||
```
|
||||
|
||||
# Citation
|
||||
If you find this work or code is helpful in your research, please cite:
|
||||
|
||||
```
|
||||
@article{wu2021cvt,
|
||||
title={Cvt: Introducing convolutions to vision transformers},
|
||||
author={Wu, Haiping and Xiao, Bin and Codella, Noel and Liu, Mengchen and Dai, Xiyang and Yuan, Lu and Zhang, Lei},
|
||||
journal={arXiv preprint arXiv:2103.15808},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
## Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
|
|
25
SUPPORT.md
25
SUPPORT.md
|
@ -1,25 +0,0 @@
|
|||
# TODO: The maintainer of this repo has not yet edited this file
|
||||
|
||||
**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
|
||||
|
||||
- **No CSS support:** Fill out this template with information about how to file issues and get help.
|
||||
- **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
|
||||
- **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
|
||||
|
||||
*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
|
||||
|
||||
# Support
|
||||
|
||||
## How to file issues and get help
|
||||
|
||||
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
|
||||
issues before filing new issues to avoid duplicates. For new issues, file your bug or
|
||||
feature request as a new Issue.
|
||||
|
||||
For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
|
||||
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
|
||||
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
|
||||
|
||||
## Microsoft Support Policy
|
||||
|
||||
Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
|
|
@ -0,0 +1,83 @@
|
|||
OUTPUT_DIR: 'OUTPUT/'
|
||||
WORKERS: 6
|
||||
PRINT_FREQ: 500
|
||||
AMP:
|
||||
ENABLED: true
|
||||
|
||||
MODEL:
|
||||
NAME: cls_cvt
|
||||
SPEC:
|
||||
INIT: 'trunc_norm'
|
||||
NUM_STAGES: 3
|
||||
PATCH_SIZE: [7, 3, 3]
|
||||
PATCH_STRIDE: [4, 2, 2]
|
||||
PATCH_PADDING: [2, 1, 1]
|
||||
DIM_EMBED: [64, 192, 384]
|
||||
NUM_HEADS: [1, 3, 6]
|
||||
DEPTH: [1, 2, 10]
|
||||
MLP_RATIO: [4.0, 4.0, 4.0]
|
||||
ATTN_DROP_RATE: [0.0, 0.0, 0.0]
|
||||
DROP_RATE: [0.0, 0.0, 0.0]
|
||||
DROP_PATH_RATE: [0.0, 0.0, 0.1]
|
||||
QKV_BIAS: [True, True, True]
|
||||
CLS_TOKEN: [False, False, True]
|
||||
POS_EMBED: [False, False, False]
|
||||
QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn']
|
||||
KERNEL_QKV: [3, 3, 3]
|
||||
PADDING_KV: [1, 1, 1]
|
||||
STRIDE_KV: [2, 2, 2]
|
||||
PADDING_Q: [1, 1, 1]
|
||||
STRIDE_Q: [1, 1, 1]
|
||||
AUG:
|
||||
MIXUP_PROB: 1.0
|
||||
MIXUP: 0.8
|
||||
MIXCUT: 1.0
|
||||
TIMM_AUG:
|
||||
USE_LOADER: true
|
||||
RE_COUNT: 1
|
||||
RE_MODE: pixel
|
||||
RE_SPLIT: false
|
||||
RE_PROB: 0.25
|
||||
AUTO_AUGMENT: rand-m9-mstd0.5-inc1
|
||||
HFLIP: 0.5
|
||||
VFLIP: 0.0
|
||||
COLOR_JITTER: 0.4
|
||||
INTERPOLATION: bicubic
|
||||
LOSS:
|
||||
LABEL_SMOOTHING: 0.1
|
||||
CUDNN:
|
||||
BENCHMARK: true
|
||||
DETERMINISTIC: false
|
||||
ENABLED: true
|
||||
DATASET:
|
||||
DATASET: 'imagenet'
|
||||
DATA_FORMAT: 'jpg'
|
||||
ROOT: 'DATASET/imagenet/'
|
||||
TEST_SET: 'val'
|
||||
TRAIN_SET: 'train'
|
||||
TEST:
|
||||
BATCH_SIZE_PER_GPU: 32
|
||||
IMAGE_SIZE: [224, 224]
|
||||
MODEL_FILE: ''
|
||||
INTERPOLATION: 3
|
||||
TRAIN:
|
||||
BATCH_SIZE_PER_GPU: 256
|
||||
LR: 0.00025
|
||||
IMAGE_SIZE: [224, 224]
|
||||
BEGIN_EPOCH: 0
|
||||
END_EPOCH: 300
|
||||
LR_SCHEDULER:
|
||||
METHOD: 'timm'
|
||||
ARGS:
|
||||
sched: 'cosine'
|
||||
warmup_epochs: 5
|
||||
warmup_lr: 0.000001
|
||||
min_lr: 0.00001
|
||||
cooldown_epochs: 10
|
||||
decay_rate: 0.1
|
||||
OPTIMIZER: adamW
|
||||
WD: 0.05
|
||||
WITHOUT_WD_LIST: ['bn', 'bias', 'ln']
|
||||
SHUFFLE: true
|
||||
DEBUG:
|
||||
DEBUG: false
|
|
@ -0,0 +1,84 @@
|
|||
OUTPUT_DIR: 'OUTPUT/'
|
||||
WORKERS: 6
|
||||
PRINT_FREQ: 500
|
||||
AMP:
|
||||
ENABLED: true
|
||||
|
||||
MODEL:
|
||||
NAME: cls_cvt
|
||||
SPEC:
|
||||
INIT: 'trunc_norm'
|
||||
NUM_STAGES: 3
|
||||
PATCH_SIZE: [7, 3, 3]
|
||||
PATCH_STRIDE: [4, 2, 2]
|
||||
PATCH_PADDING: [2, 1, 1]
|
||||
DIM_EMBED: [64, 192, 384]
|
||||
NUM_HEADS: [1, 3, 6]
|
||||
DEPTH: [1, 2, 10]
|
||||
MLP_RATIO: [4.0, 4.0, 4.0]
|
||||
ATTN_DROP_RATE: [0.0, 0.0, 0.0]
|
||||
DROP_RATE: [0.0, 0.0, 0.0]
|
||||
DROP_PATH_RATE: [0.0, 0.0, 0.1]
|
||||
QKV_BIAS: [True, True, True]
|
||||
CLS_TOKEN: [False, False, True]
|
||||
POS_EMBED: [False, False, False]
|
||||
QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn']
|
||||
KERNEL_QKV: [3, 3, 3]
|
||||
PADDING_KV: [1, 1, 1]
|
||||
STRIDE_KV: [2, 2, 2]
|
||||
PADDING_Q: [1, 1, 1]
|
||||
STRIDE_Q: [1, 1, 1]
|
||||
AUG:
|
||||
MIXUP_PROB: 1.0
|
||||
MIXUP: 0.8
|
||||
MIXCUT: 1.0
|
||||
TIMM_AUG:
|
||||
USE_LOADER: true
|
||||
RE_COUNT: 1
|
||||
RE_MODE: pixel
|
||||
RE_SPLIT: false
|
||||
RE_PROB: 0.25
|
||||
AUTO_AUGMENT: rand-m9-mstd0.5-inc1
|
||||
HFLIP: 0.5
|
||||
VFLIP: 0.0
|
||||
COLOR_JITTER: 0.4
|
||||
INTERPOLATION: bicubic
|
||||
LOSS:
|
||||
LABEL_SMOOTHING: 0.1
|
||||
CUDNN:
|
||||
BENCHMARK: true
|
||||
DETERMINISTIC: false
|
||||
ENABLED: true
|
||||
DATASET:
|
||||
DATASET: 'imagenet'
|
||||
DATA_FORMAT: 'jpg'
|
||||
ROOT: 'DATASET/imagenet/'
|
||||
TEST_SET: 'val'
|
||||
TRAIN_SET: 'train'
|
||||
TEST:
|
||||
BATCH_SIZE_PER_GPU: 32
|
||||
IMAGE_SIZE: [384, 384]
|
||||
CENTER_CROP: False
|
||||
MODEL_FILE: ''
|
||||
INTERPOLATION: 3
|
||||
TRAIN:
|
||||
BATCH_SIZE_PER_GPU: 256
|
||||
LR: 0.00025
|
||||
IMAGE_SIZE: [384, 384]
|
||||
BEGIN_EPOCH: 0
|
||||
END_EPOCH: 300
|
||||
LR_SCHEDULER:
|
||||
METHOD: 'timm'
|
||||
ARGS:
|
||||
sched: 'cosine'
|
||||
warmup_epochs: 5
|
||||
warmup_lr: 0.000001
|
||||
min_lr: 0.00001
|
||||
cooldown_epochs: 10
|
||||
decay_rate: 0.1
|
||||
OPTIMIZER: adamW
|
||||
WD: 0.05
|
||||
WITHOUT_WD_LIST: ['bn', 'bias', 'ln']
|
||||
SHUFFLE: true
|
||||
DEBUG:
|
||||
DEBUG: false
|
|
@ -0,0 +1,84 @@
|
|||
OUTPUT_DIR: 'OUTPUT/'
|
||||
WORKERS: 6
|
||||
PRINT_FREQ: 500
|
||||
AMP:
|
||||
ENABLED: true
|
||||
|
||||
MODEL:
|
||||
NAME: cls_cvt
|
||||
SPEC:
|
||||
INIT: 'trunc_norm'
|
||||
NUM_STAGES: 3
|
||||
PATCH_SIZE: [7, 3, 3]
|
||||
PATCH_STRIDE: [4, 2, 2]
|
||||
PATCH_PADDING: [2, 1, 1]
|
||||
DIM_EMBED: [64, 192, 384]
|
||||
NUM_HEADS: [1, 3, 6]
|
||||
DEPTH: [1, 4, 16]
|
||||
MLP_RATIO: [4.0, 4.0, 4.0]
|
||||
ATTN_DROP_RATE: [0.0, 0.0, 0.0]
|
||||
DROP_RATE: [0.0, 0.0, 0.0]
|
||||
DROP_PATH_RATE: [0.0, 0.0, 0.1]
|
||||
QKV_BIAS: [True, True, True]
|
||||
CLS_TOKEN: [False, False, True]
|
||||
POS_EMBED: [False, False, False]
|
||||
QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn']
|
||||
KERNEL_QKV: [3, 3, 3]
|
||||
PADDING_KV: [1, 1, 1]
|
||||
STRIDE_KV: [2, 2, 2]
|
||||
PADDING_Q: [1, 1, 1]
|
||||
STRIDE_Q: [1, 1, 1]
|
||||
AUG:
|
||||
MIXUP_PROB: 1.0
|
||||
MIXUP: 0.8
|
||||
MIXCUT: 1.0
|
||||
TIMM_AUG:
|
||||
USE_LOADER: false
|
||||
RE_COUNT: 1
|
||||
RE_MODE: pixel
|
||||
RE_SPLIT: false
|
||||
RE_PROB: 0.25
|
||||
AUTO_AUGMENT: rand-m9-mstd0.5-inc1
|
||||
HFLIP: 0.5
|
||||
VFLIP: 0.0
|
||||
COLOR_JITTER: 0.4
|
||||
INTERPOLATION: bicubic
|
||||
LOSS:
|
||||
LABEL_SMOOTHING: 0.1
|
||||
CUDNN:
|
||||
BENCHMARK: true
|
||||
DETERMINISTIC: false
|
||||
ENABLED: true
|
||||
DATASET:
|
||||
DATASET: 'imagenet'
|
||||
DATA_FORMAT: 'jpg'
|
||||
ROOT: 'DATASET/imagenet/'
|
||||
TEST_SET: 'val'
|
||||
TRAIN_SET: 'train'
|
||||
SAMPLER: repeated_aug
|
||||
TEST:
|
||||
BATCH_SIZE_PER_GPU: 32
|
||||
IMAGE_SIZE: [224, 224]
|
||||
MODEL_FILE: ''
|
||||
INTERPOLATION: 3
|
||||
TRAIN:
|
||||
BATCH_SIZE_PER_GPU: 128
|
||||
LR: 0.000125
|
||||
IMAGE_SIZE: [224, 224]
|
||||
BEGIN_EPOCH: 0
|
||||
END_EPOCH: 300
|
||||
LR_SCHEDULER:
|
||||
METHOD: 'timm'
|
||||
ARGS:
|
||||
sched: 'cosine'
|
||||
warmup_epochs: 5
|
||||
warmup_lr: 0.000001
|
||||
min_lr: 0.00001
|
||||
cooldown_epochs: 10
|
||||
decay_rate: 0.1
|
||||
OPTIMIZER: adamW
|
||||
WD: 0.1
|
||||
WITHOUT_WD_LIST: ['bn', 'bias', 'ln']
|
||||
SHUFFLE: true
|
||||
DEBUG:
|
||||
DEBUG: false
|
|
@ -0,0 +1,84 @@
|
|||
OUTPUT_DIR: 'OUTPUT/'
|
||||
WORKERS: 6
|
||||
PRINT_FREQ: 500
|
||||
AMP:
|
||||
ENABLED: true
|
||||
|
||||
MODEL:
|
||||
NAME: cls_cvt
|
||||
SPEC:
|
||||
INIT: 'trunc_norm'
|
||||
NUM_STAGES: 3
|
||||
PATCH_SIZE: [7, 3, 3]
|
||||
PATCH_STRIDE: [4, 2, 2]
|
||||
PATCH_PADDING: [2, 1, 1]
|
||||
DIM_EMBED: [64, 192, 384]
|
||||
NUM_HEADS: [1, 3, 6]
|
||||
DEPTH: [1, 4, 16]
|
||||
MLP_RATIO: [4.0, 4.0, 4.0]
|
||||
ATTN_DROP_RATE: [0.0, 0.0, 0.0]
|
||||
DROP_RATE: [0.0, 0.0, 0.0]
|
||||
DROP_PATH_RATE: [0.0, 0.0, 0.1]
|
||||
QKV_BIAS: [True, True, True]
|
||||
CLS_TOKEN: [False, False, True]
|
||||
POS_EMBED: [False, False, False]
|
||||
QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn']
|
||||
KERNEL_QKV: [3, 3, 3]
|
||||
PADDING_KV: [1, 1, 1]
|
||||
STRIDE_KV: [2, 2, 2]
|
||||
PADDING_Q: [1, 1, 1]
|
||||
STRIDE_Q: [1, 1, 1]
|
||||
AUG:
|
||||
MIXUP_PROB: 1.0
|
||||
MIXUP: 0.8
|
||||
MIXCUT: 1.0
|
||||
TIMM_AUG:
|
||||
USE_LOADER: true
|
||||
RE_COUNT: 1
|
||||
RE_MODE: pixel
|
||||
RE_SPLIT: false
|
||||
RE_PROB: 0.25
|
||||
AUTO_AUGMENT: rand-m9-mstd0.5-inc1
|
||||
HFLIP: 0.5
|
||||
VFLIP: 0.0
|
||||
COLOR_JITTER: 0.4
|
||||
INTERPOLATION: bicubic
|
||||
LOSS:
|
||||
LABEL_SMOOTHING: 0.1
|
||||
CUDNN:
|
||||
BENCHMARK: true
|
||||
DETERMINISTIC: false
|
||||
ENABLED: true
|
||||
DATASET:
|
||||
DATASET: 'imagenet'
|
||||
DATA_FORMAT: 'jpg'
|
||||
ROOT: 'DATASET/imagenet/'
|
||||
TEST_SET: 'val'
|
||||
TRAIN_SET: 'train'
|
||||
TEST:
|
||||
BATCH_SIZE_PER_GPU: 32
|
||||
IMAGE_SIZE: [384, 384]
|
||||
MODEL_FILE: ''
|
||||
INTERPOLATION: 3
|
||||
CENTER_CROP: False
|
||||
TRAIN:
|
||||
BATCH_SIZE_PER_GPU: 128
|
||||
LR: 0.000125
|
||||
IMAGE_SIZE: [384, 384]
|
||||
BEGIN_EPOCH: 0
|
||||
END_EPOCH: 300
|
||||
LR_SCHEDULER:
|
||||
METHOD: 'timm'
|
||||
ARGS:
|
||||
sched: 'cosine'
|
||||
warmup_epochs: 5
|
||||
warmup_lr: 0.000001
|
||||
min_lr: 0.00001
|
||||
cooldown_epochs: 10
|
||||
decay_rate: 0.1
|
||||
OPTIMIZER: adamW
|
||||
WD: 0.1
|
||||
WITHOUT_WD_LIST: ['bn', 'bias', 'ln']
|
||||
SHUFFLE: true
|
||||
DEBUG:
|
||||
DEBUG: false
|
|
@ -0,0 +1,84 @@
|
|||
OUTPUT_DIR: 'OUTPUT/'
|
||||
WORKERS: 6
|
||||
PRINT_FREQ: 500
|
||||
AMP:
|
||||
ENABLED: true
|
||||
|
||||
MODEL:
|
||||
NAME: cls_cvt
|
||||
SPEC:
|
||||
INIT: 'trunc_norm'
|
||||
NUM_STAGES: 3
|
||||
PATCH_SIZE: [7, 3, 3]
|
||||
PATCH_STRIDE: [4, 2, 2]
|
||||
PATCH_PADDING: [2, 1, 1]
|
||||
DIM_EMBED: [192, 768, 1024]
|
||||
NUM_HEADS: [3, 12, 16]
|
||||
DEPTH: [2, 2, 20]
|
||||
MLP_RATIO: [4.0, 4.0, 4.0]
|
||||
ATTN_DROP_RATE: [0.0, 0.0, 0.0]
|
||||
DROP_RATE: [0.0, 0.0, 0.0]
|
||||
DROP_PATH_RATE: [0.0, 0.0, 0.3]
|
||||
QKV_BIAS: [True, True, True]
|
||||
CLS_TOKEN: [False, False, True]
|
||||
POS_EMBED: [False, False, False]
|
||||
QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn']
|
||||
KERNEL_QKV: [3, 3, 3]
|
||||
PADDING_KV: [1, 1, 1]
|
||||
STRIDE_KV: [2, 2, 2]
|
||||
PADDING_Q: [1, 1, 1]
|
||||
STRIDE_Q: [1, 1, 1]
|
||||
AUG:
|
||||
MIXUP_PROB: 1.0
|
||||
MIXUP: 0.8
|
||||
MIXCUT: 1.0
|
||||
TIMM_AUG:
|
||||
USE_LOADER: true
|
||||
RE_COUNT: 1
|
||||
RE_MODE: pixel
|
||||
RE_SPLIT: false
|
||||
RE_PROB: 0.25
|
||||
AUTO_AUGMENT: rand-m9-mstd0.5-inc1
|
||||
HFLIP: 0.5
|
||||
VFLIP: 0.0
|
||||
COLOR_JITTER: 0.4
|
||||
INTERPOLATION: bicubic
|
||||
LOSS:
|
||||
LABEL_SMOOTHING: 0.1
|
||||
CUDNN:
|
||||
BENCHMARK: true
|
||||
DETERMINISTIC: false
|
||||
ENABLED: true
|
||||
DATASET:
|
||||
DATASET: 'imagenet'
|
||||
DATA_FORMAT: 'jpg'
|
||||
ROOT: 'DATASET/imagenet/'
|
||||
TEST_SET: 'val'
|
||||
TRAIN_SET: 'train'
|
||||
TEST:
|
||||
BATCH_SIZE_PER_GPU: 32
|
||||
IMAGE_SIZE: [384, 384]
|
||||
MODEL_FILE: ''
|
||||
INTERPOLATION: 3
|
||||
CENTER_CROP: False
|
||||
TRAIN:
|
||||
BATCH_SIZE_PER_GPU: 128
|
||||
LR: 0.000125
|
||||
IMAGE_SIZE: [384, 384]
|
||||
BEGIN_EPOCH: 0
|
||||
END_EPOCH: 300
|
||||
LR_SCHEDULER:
|
||||
METHOD: 'timm'
|
||||
ARGS:
|
||||
sched: 'cosine'
|
||||
warmup_epochs: 5
|
||||
warmup_lr: 0.000001
|
||||
min_lr: 0.00001
|
||||
cooldown_epochs: 10
|
||||
decay_rate: 0.1
|
||||
OPTIMIZER: adamW
|
||||
WD: 0.1
|
||||
WITHOUT_WD_LIST: ['bn', 'bias', 'ln']
|
||||
SHUFFLE: true
|
||||
DEBUG:
|
||||
DEBUG: false
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
После Ширина: | Высота: | Размер: 923 KiB |
|
@ -0,0 +1,4 @@
|
|||
from .default import _C as config
|
||||
from .default import update_config
|
||||
from .default import _update_config_from_file
|
||||
from .default import save_config
|
|
@ -0,0 +1,203 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path as op
|
||||
import yaml
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
from lib.utils.comm import comm
|
||||
|
||||
|
||||
_C = CN()
|
||||
|
||||
_C.BASE = ['']
|
||||
_C.NAME = ''
|
||||
_C.DATA_DIR = ''
|
||||
_C.DIST_BACKEND = 'nccl'
|
||||
_C.GPUS = (0,)
|
||||
# _C.LOG_DIR = ''
|
||||
_C.MULTIPROCESSING_DISTRIBUTED = True
|
||||
_C.OUTPUT_DIR = ''
|
||||
_C.PIN_MEMORY = True
|
||||
_C.PRINT_FREQ = 20
|
||||
_C.RANK = 0
|
||||
_C.VERBOSE = True
|
||||
_C.WORKERS = 4
|
||||
_C.MODEL_SUMMARY = False
|
||||
|
||||
_C.AMP = CN()
|
||||
_C.AMP.ENABLED = False
|
||||
_C.AMP.MEMORY_FORMAT = 'nchw'
|
||||
|
||||
# Cudnn related params
|
||||
_C.CUDNN = CN()
|
||||
_C.CUDNN.BENCHMARK = True
|
||||
_C.CUDNN.DETERMINISTIC = False
|
||||
_C.CUDNN.ENABLED = True
|
||||
|
||||
# common params for NETWORK
|
||||
_C.MODEL = CN()
|
||||
_C.MODEL.NAME = 'cls_hrnet'
|
||||
_C.MODEL.INIT_WEIGHTS = True
|
||||
_C.MODEL.PRETRAINED = ''
|
||||
_C.MODEL.PRETRAINED_LAYERS = ['*']
|
||||
_C.MODEL.NUM_CLASSES = 1000
|
||||
_C.MODEL.SPEC = CN(new_allowed=True)
|
||||
|
||||
_C.LOSS = CN(new_allowed=True)
|
||||
_C.LOSS.LABEL_SMOOTHING = 0.0
|
||||
_C.LOSS.LOSS = 'softmax'
|
||||
|
||||
# DATASET related params
|
||||
_C.DATASET = CN()
|
||||
_C.DATASET.ROOT = ''
|
||||
_C.DATASET.DATASET = 'imagenet'
|
||||
_C.DATASET.TRAIN_SET = 'train'
|
||||
_C.DATASET.TEST_SET = 'val'
|
||||
_C.DATASET.DATA_FORMAT = 'jpg'
|
||||
_C.DATASET.LABELMAP = ''
|
||||
_C.DATASET.TRAIN_TSV_LIST = []
|
||||
_C.DATASET.TEST_TSV_LIST = []
|
||||
_C.DATASET.SAMPLER = 'default'
|
||||
|
||||
_C.DATASET.TARGET_SIZE = -1
|
||||
|
||||
# training data augmentation
|
||||
_C.INPUT = CN()
|
||||
_C.INPUT.MEAN = [0.485, 0.456, 0.406]
|
||||
_C.INPUT.STD = [0.229, 0.224, 0.225]
|
||||
|
||||
# data augmentation
|
||||
_C.AUG = CN()
|
||||
_C.AUG.SCALE = (0.08, 1.0)
|
||||
_C.AUG.RATIO = (3.0/4.0, 4.0/3.0)
|
||||
_C.AUG.COLOR_JITTER = [0.4, 0.4, 0.4, 0.1, 0.0]
|
||||
_C.AUG.GRAY_SCALE = 0.0
|
||||
_C.AUG.GAUSSIAN_BLUR = 0.0
|
||||
_C.AUG.DROPBLOCK_LAYERS = [3, 4]
|
||||
_C.AUG.DROPBLOCK_KEEP_PROB = 1.0
|
||||
_C.AUG.DROPBLOCK_BLOCK_SIZE = 7
|
||||
_C.AUG.MIXUP_PROB = 0.0
|
||||
_C.AUG.MIXUP = 0.0
|
||||
_C.AUG.MIXCUT = 0.0
|
||||
_C.AUG.MIXCUT_MINMAX = []
|
||||
_C.AUG.MIXUP_SWITCH_PROB = 0.5
|
||||
_C.AUG.MIXUP_MODE = 'batch'
|
||||
_C.AUG.MIXCUT_AND_MIXUP = False
|
||||
_C.AUG.INTERPOLATION = 2
|
||||
_C.AUG.TIMM_AUG = CN(new_allowed=True)
|
||||
_C.AUG.TIMM_AUG.USE_LOADER = False
|
||||
_C.AUG.TIMM_AUG.USE_TRANSFORM = False
|
||||
|
||||
# train
|
||||
_C.TRAIN = CN()
|
||||
|
||||
_C.TRAIN.AUTO_RESUME = True
|
||||
_C.TRAIN.CHECKPOINT = ''
|
||||
_C.TRAIN.LR_SCHEDULER = CN(new_allowed=True)
|
||||
_C.TRAIN.SCALE_LR = True
|
||||
_C.TRAIN.LR = 0.001
|
||||
|
||||
_C.TRAIN.OPTIMIZER = 'sgd'
|
||||
_C.TRAIN.OPTIMIZER_ARGS = CN(new_allowed=True)
|
||||
_C.TRAIN.MOMENTUM = 0.9
|
||||
_C.TRAIN.WD = 0.0001
|
||||
_C.TRAIN.WITHOUT_WD_LIST = []
|
||||
_C.TRAIN.NESTEROV = True
|
||||
# for adam
|
||||
_C.TRAIN.GAMMA1 = 0.99
|
||||
_C.TRAIN.GAMMA2 = 0.0
|
||||
|
||||
_C.TRAIN.BEGIN_EPOCH = 0
|
||||
_C.TRAIN.END_EPOCH = 100
|
||||
|
||||
_C.TRAIN.IMAGE_SIZE = [224, 224] # width * height, ex: 192 * 256
|
||||
_C.TRAIN.BATCH_SIZE_PER_GPU = 32
|
||||
_C.TRAIN.SHUFFLE = True
|
||||
|
||||
_C.TRAIN.EVAL_BEGIN_EPOCH = 0
|
||||
|
||||
_C.TRAIN.DETECT_ANOMALY = False
|
||||
|
||||
_C.TRAIN.CLIP_GRAD_NORM = 0.0
|
||||
_C.TRAIN.SAVE_ALL_MODELS = False
|
||||
|
||||
# testing
|
||||
_C.TEST = CN()
|
||||
|
||||
# size of images for each device
|
||||
_C.TEST.BATCH_SIZE_PER_GPU = 32
|
||||
_C.TEST.CENTER_CROP = True
|
||||
_C.TEST.IMAGE_SIZE = [224, 224] # width * height, ex: 192 * 256
|
||||
_C.TEST.INTERPOLATION = 2
|
||||
_C.TEST.MODEL_FILE = ''
|
||||
_C.TEST.REAL_LABELS = False
|
||||
_C.TEST.VALID_LABELS = ''
|
||||
|
||||
_C.FINETUNE = CN()
|
||||
_C.FINETUNE.FINETUNE = False
|
||||
_C.FINETUNE.USE_TRAIN_AUG = False
|
||||
_C.FINETUNE.BASE_LR = 0.003
|
||||
_C.FINETUNE.BATCH_SIZE = 512
|
||||
_C.FINETUNE.EVAL_EVERY = 3000
|
||||
_C.FINETUNE.TRAIN_MODE = True
|
||||
# _C.FINETUNE.MODEL_FILE = ''
|
||||
_C.FINETUNE.FROZEN_LAYERS = []
|
||||
_C.FINETUNE.LR_SCHEDULER = CN(new_allowed=True)
|
||||
_C.FINETUNE.LR_SCHEDULER.DECAY_TYPE = 'step'
|
||||
|
||||
# debug
|
||||
_C.DEBUG = CN()
|
||||
_C.DEBUG.DEBUG = False
|
||||
|
||||
|
||||
def _update_config_from_file(config, cfg_file):
|
||||
config.defrost()
|
||||
with open(cfg_file, 'r') as f:
|
||||
yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
for cfg in yaml_cfg.setdefault('BASE', ['']):
|
||||
if cfg:
|
||||
_update_config_from_file(
|
||||
config, op.join(op.dirname(cfg_file), cfg)
|
||||
)
|
||||
print('=> merge config from {}'.format(cfg_file))
|
||||
config.merge_from_file(cfg_file)
|
||||
config.freeze()
|
||||
|
||||
|
||||
def update_config(config, args):
|
||||
_update_config_from_file(config, args.cfg)
|
||||
|
||||
config.defrost()
|
||||
config.merge_from_list(args.opts)
|
||||
if config.TRAIN.SCALE_LR:
|
||||
config.TRAIN.LR *= comm.world_size
|
||||
file_name, _ = op.splitext(op.basename(args.cfg))
|
||||
config.NAME = file_name + config.NAME
|
||||
config.RANK = comm.rank
|
||||
|
||||
if 'timm' == config.TRAIN.LR_SCHEDULER.METHOD:
|
||||
config.TRAIN.LR_SCHEDULER.ARGS.epochs = config.TRAIN.END_EPOCH
|
||||
|
||||
if 'timm' == config.TRAIN.OPTIMIZER:
|
||||
config.TRAIN.OPTIMIZER_ARGS.lr = config.TRAIN.LR
|
||||
|
||||
aug = config.AUG
|
||||
if aug.MIXUP > 0.0 or aug.MIXCUT > 0.0 or aug.MIXCUT_MINMAX:
|
||||
aug.MIXUP_PROB = 1.0
|
||||
config.freeze()
|
||||
|
||||
|
||||
def save_config(cfg, path):
|
||||
if comm.is_main_process():
|
||||
with open(path, 'w') as f:
|
||||
f.write(cfg.dump())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
with open(sys.argv[1], 'w') as f:
|
||||
print(_C, file=f)
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
if isinstance(output, list):
|
||||
output = output[-1]
|
||||
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size).item())
|
||||
return res
|
|
@ -0,0 +1,223 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import time
|
||||
import torch
|
||||
|
||||
from timm.data import Mixup
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from core.evaluate import accuracy
|
||||
from utils.comm import comm
|
||||
|
||||
|
||||
def train_one_epoch(config, train_loader, model, criterion, optimizer, epoch,
|
||||
output_dir, tb_log_dir, writer_dict, scaler=None):
|
||||
batch_time = AverageMeter()
|
||||
data_time = AverageMeter()
|
||||
losses = AverageMeter()
|
||||
top1 = AverageMeter()
|
||||
top5 = AverageMeter()
|
||||
|
||||
logging.info('=> switch to train mode')
|
||||
model.train()
|
||||
|
||||
aug = config.AUG
|
||||
mixup_fn = Mixup(
|
||||
mixup_alpha=aug.MIXUP, cutmix_alpha=aug.MIXCUT,
|
||||
cutmix_minmax=aug.MIXCUT_MINMAX if aug.MIXCUT_MINMAX else None,
|
||||
prob=aug.MIXUP_PROB, switch_prob=aug.MIXUP_SWITCH_PROB,
|
||||
mode=aug.MIXUP_MODE, label_smoothing=config.LOSS.LABEL_SMOOTHING,
|
||||
num_classes=config.MODEL.NUM_CLASSES
|
||||
) if aug.MIXUP_PROB > 0.0 else None
|
||||
end = time.time()
|
||||
for i, (x, y) in enumerate(train_loader):
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# compute output
|
||||
x = x.cuda(non_blocking=True)
|
||||
y = y.cuda(non_blocking=True)
|
||||
|
||||
if mixup_fn:
|
||||
x, y = mixup_fn(x, y)
|
||||
|
||||
with autocast(enabled=config.AMP.ENABLED):
|
||||
if config.AMP.ENABLED and config.AMP.MEMORY_FORMAT == 'nwhc':
|
||||
x = x.contiguous(memory_format=torch.channels_last)
|
||||
y = y.contiguous(memory_format=torch.channels_last)
|
||||
|
||||
outputs = model(x)
|
||||
loss = criterion(outputs, y)
|
||||
|
||||
# compute gradient and do update step
|
||||
optimizer.zero_grad()
|
||||
is_second_order = hasattr(optimizer, 'is_second_order') \
|
||||
and optimizer.is_second_order
|
||||
|
||||
scaler.scale(loss).backward(create_graph=is_second_order)
|
||||
|
||||
if config.TRAIN.CLIP_GRAD_NORM > 0.0:
|
||||
# Unscales the gradients of optimizer's assigned params in-place
|
||||
scaler.unscale_(optimizer)
|
||||
|
||||
# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(), config.TRAIN.CLIP_GRAD_NORM
|
||||
)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
# measure accuracy and record loss
|
||||
losses.update(loss.item(), x.size(0))
|
||||
|
||||
if mixup_fn:
|
||||
y = torch.argmax(y, dim=1)
|
||||
prec1, prec5 = accuracy(outputs, y, (1, 5))
|
||||
|
||||
top1.update(prec1, x.size(0))
|
||||
top5.update(prec5, x.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % config.PRINT_FREQ == 0:
|
||||
msg = '=> Epoch[{0}][{1}/{2}]: ' \
|
||||
'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
|
||||
'Speed {speed:.1f} samples/s\t' \
|
||||
'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
|
||||
'Loss {loss.val:.5f} ({loss.avg:.5f})\t' \
|
||||
'Accuracy@1 {top1.val:.3f} ({top1.avg:.3f})\t' \
|
||||
'Accuracy@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
|
||||
epoch, i, len(train_loader),
|
||||
batch_time=batch_time,
|
||||
speed=x.size(0)/batch_time.val,
|
||||
data_time=data_time, loss=losses, top1=top1, top5=top5)
|
||||
logging.info(msg)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if writer_dict and comm.is_main_process():
|
||||
writer = writer_dict['writer']
|
||||
global_steps = writer_dict['train_global_steps']
|
||||
writer.add_scalar('train_loss', losses.avg, global_steps)
|
||||
writer.add_scalar('train_top1', top1.avg, global_steps)
|
||||
writer_dict['train_global_steps'] = global_steps + 1
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test(config, val_loader, model, criterion, output_dir, tb_log_dir,
|
||||
writer_dict=None, distributed=False, real_labels=None,
|
||||
valid_labels=None):
|
||||
batch_time = AverageMeter()
|
||||
losses = AverageMeter()
|
||||
top1 = AverageMeter()
|
||||
top5 = AverageMeter()
|
||||
|
||||
logging.info('=> switch to eval mode')
|
||||
model.eval()
|
||||
|
||||
end = time.time()
|
||||
for i, (x, y) in enumerate(val_loader):
|
||||
# compute output
|
||||
x = x.cuda(non_blocking=True)
|
||||
y = y.cuda(non_blocking=True)
|
||||
|
||||
outputs = model(x)
|
||||
if valid_labels:
|
||||
outputs = outputs[:, valid_labels]
|
||||
|
||||
loss = criterion(outputs, y)
|
||||
|
||||
if real_labels and not distributed:
|
||||
real_labels.add_result(outputs)
|
||||
|
||||
# measure accuracy and record loss
|
||||
losses.update(loss.item(), x.size(0))
|
||||
|
||||
prec1, prec5 = accuracy(outputs, y, (1, 5))
|
||||
top1.update(prec1, x.size(0))
|
||||
top5.update(prec5, x.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
logging.info('=> synchronize...')
|
||||
comm.synchronize()
|
||||
top1_acc, top5_acc, loss_avg = map(
|
||||
_meter_reduce if distributed else lambda x: x.avg,
|
||||
[top1, top5, losses]
|
||||
)
|
||||
|
||||
if real_labels and not distributed:
|
||||
real_top1 = real_labels.get_accuracy(k=1)
|
||||
real_top5 = real_labels.get_accuracy(k=5)
|
||||
msg = '=> TEST using Reassessed labels:\t' \
|
||||
'Error@1 {error1:.3f}%\t' \
|
||||
'Error@5 {error5:.3f}%\t' \
|
||||
'Accuracy@1 {top1:.3f}%\t' \
|
||||
'Accuracy@5 {top5:.3f}%\t'.format(
|
||||
top1=real_top1,
|
||||
top5=real_top5,
|
||||
error1=100-real_top1,
|
||||
error5=100-real_top5
|
||||
)
|
||||
logging.info(msg)
|
||||
|
||||
if comm.is_main_process():
|
||||
msg = '=> TEST:\t' \
|
||||
'Loss {loss_avg:.4f}\t' \
|
||||
'Error@1 {error1:.3f}%\t' \
|
||||
'Error@5 {error5:.3f}%\t' \
|
||||
'Accuracy@1 {top1:.3f}%\t' \
|
||||
'Accuracy@5 {top5:.3f}%\t'.format(
|
||||
loss_avg=loss_avg, top1=top1_acc,
|
||||
top5=top5_acc, error1=100-top1_acc,
|
||||
error5=100-top5_acc
|
||||
)
|
||||
logging.info(msg)
|
||||
|
||||
if writer_dict and comm.is_main_process():
|
||||
writer = writer_dict['writer']
|
||||
global_steps = writer_dict['valid_global_steps']
|
||||
writer.add_scalar('valid_loss', loss_avg, global_steps)
|
||||
writer.add_scalar('valid_top1', top1_acc, global_steps)
|
||||
writer_dict['valid_global_steps'] = global_steps + 1
|
||||
|
||||
logging.info('=> switch to train mode')
|
||||
model.train()
|
||||
|
||||
return top1_acc
|
||||
|
||||
|
||||
def _meter_reduce(meter):
|
||||
rank = comm.local_rank
|
||||
meter_sum = torch.FloatTensor([meter.sum]).cuda(rank)
|
||||
meter_count = torch.FloatTensor([meter.count]).cuda(rank)
|
||||
torch.distributed.reduce(meter_sum, 0)
|
||||
torch.distributed.reduce(meter_count, 0)
|
||||
meter_avg = meter_sum / meter_count
|
||||
|
||||
return meter_avg.item()
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
|
@ -0,0 +1,49 @@
|
|||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def linear_combination(x, y, epsilon):
|
||||
return epsilon*x + (1-epsilon)*y
|
||||
|
||||
|
||||
def reduce_loss(loss, reduction='mean'):
|
||||
return loss.mean() if reduction == 'mean' \
|
||||
else loss.sum() if reduction == 'sum' else loss
|
||||
|
||||
|
||||
class LabelSmoothingCrossEntropy(nn.Module):
|
||||
def __init__(self, epsilon=0.1, reduction='mean'):
|
||||
super().__init__()
|
||||
self.epsilon = epsilon
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, preds, target):
|
||||
n = preds.size()[-1]
|
||||
log_preds = F.log_softmax(preds, dim=-1)
|
||||
loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
|
||||
nll = F.nll_loss(log_preds, target, reduction=self.reduction)
|
||||
return linear_combination(loss/n, nll, self.epsilon)
|
||||
|
||||
|
||||
class SoftTargetCrossEntropy(nn.Module):
|
||||
def __init__(self):
|
||||
super(SoftTargetCrossEntropy, self).__init__()
|
||||
|
||||
def forward(self, x, target):
|
||||
loss = th.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
|
||||
return loss.mean()
|
||||
|
||||
|
||||
def build_criterion(config, train=True):
|
||||
if config.AUG.MIXUP_PROB > 0.0 and config.LOSS.LOSS == 'softmax':
|
||||
criterion = SoftTargetCrossEntropy() \
|
||||
if train else nn.CrossEntropyLoss()
|
||||
elif config.LOSS.LABEL_SMOOTHING > 0.0 and config.LOSS.LOSS == 'softmax':
|
||||
criterion = LabelSmoothingCrossEntropy(config.LOSS.LABEL_SMOOTHING)
|
||||
elif config.LOSS.LOSS == 'softmax':
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
else:
|
||||
raise ValueError('Unkown loss {}'.format(config.LOSS.LOSS))
|
||||
|
||||
return criterion
|
|
@ -0,0 +1,2 @@
|
|||
from .build import build_dataloader
|
||||
from .imagenet.real_labels import RealLabelsImagenet
|
|
@ -0,0 +1,115 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from timm.data import create_loader
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torchvision.datasets as datasets
|
||||
|
||||
from .transformas import build_transforms
|
||||
from .samplers import RASampler
|
||||
|
||||
|
||||
def build_dataset(cfg, is_train):
|
||||
dataset = None
|
||||
if 'imagenet' in cfg.DATASET.DATASET:
|
||||
dataset = _build_imagenet_dataset(cfg, is_train)
|
||||
else:
|
||||
raise ValueError('Unkown dataset: {}'.format(cfg.DATASET.DATASET))
|
||||
return dataset
|
||||
|
||||
|
||||
def _build_image_folder_dataset(cfg, is_train):
|
||||
transforms = build_transforms(cfg, is_train)
|
||||
|
||||
dataset_name = cfg.DATASET.TRAIN_SET if is_train else cfg.DATASET.TEST_SET
|
||||
dataset = datasets.ImageFolder(
|
||||
os.path.join(cfg.DATASET.ROOT, dataset_name), transforms
|
||||
)
|
||||
logging.info(
|
||||
'=> load samples: {}, is_train: {}'
|
||||
.format(len(dataset), is_train)
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def _build_imagenet_dataset(cfg, is_train):
|
||||
transforms = build_transforms(cfg, is_train)
|
||||
|
||||
dataset_name = cfg.DATASET.TRAIN_SET if is_train else cfg.DATASET.TEST_SET
|
||||
dataset = datasets.ImageFolder(
|
||||
os.path.join(cfg.DATASET.ROOT, dataset_name), transforms
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def build_dataloader(cfg, is_train=True, distributed=False):
|
||||
if is_train:
|
||||
batch_size_per_gpu = cfg.TRAIN.BATCH_SIZE_PER_GPU
|
||||
shuffle = True
|
||||
else:
|
||||
batch_size_per_gpu = cfg.TEST.BATCH_SIZE_PER_GPU
|
||||
shuffle = False
|
||||
|
||||
dataset = build_dataset(cfg, is_train)
|
||||
|
||||
if distributed:
|
||||
if is_train and cfg.DATASET.SAMPLER == 'repeated_aug':
|
||||
logging.info('=> use repeated aug sampler')
|
||||
sampler = RASampler(dataset, shuffle=shuffle)
|
||||
else:
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
dataset, shuffle=shuffle
|
||||
)
|
||||
shuffle = False
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
if cfg.AUG.TIMM_AUG.USE_LOADER and is_train:
|
||||
logging.info('=> use timm loader for training')
|
||||
timm_cfg = cfg.AUG.TIMM_AUG
|
||||
data_loader = create_loader(
|
||||
dataset,
|
||||
input_size=cfg.TRAIN.IMAGE_SIZE[0],
|
||||
batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU,
|
||||
is_training=True,
|
||||
use_prefetcher=True,
|
||||
no_aug=False,
|
||||
re_prob=timm_cfg.RE_PROB,
|
||||
re_mode=timm_cfg.RE_MODE,
|
||||
re_count=timm_cfg.RE_COUNT,
|
||||
re_split=timm_cfg.RE_SPLIT,
|
||||
scale=cfg.AUG.SCALE,
|
||||
ratio=cfg.AUG.RATIO,
|
||||
hflip=timm_cfg.HFLIP,
|
||||
vflip=timm_cfg.VFLIP,
|
||||
color_jitter=timm_cfg.COLOR_JITTER,
|
||||
auto_augment=timm_cfg.AUTO_AUGMENT,
|
||||
num_aug_splits=0,
|
||||
interpolation=timm_cfg.INTERPOLATION,
|
||||
mean=cfg.INPUT.MEAN,
|
||||
std=cfg.INPUT.STD,
|
||||
num_workers=cfg.WORKERS,
|
||||
distributed=distributed,
|
||||
collate_fn=None,
|
||||
pin_memory=cfg.PIN_MEMORY,
|
||||
use_multi_epochs_loader=True
|
||||
)
|
||||
else:
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size_per_gpu,
|
||||
shuffle=shuffle,
|
||||
num_workers=cfg.WORKERS,
|
||||
pin_memory=cfg.PIN_MEMORY,
|
||||
sampler=sampler,
|
||||
drop_last=True if is_train else False,
|
||||
)
|
||||
|
||||
return data_loader
|
|
@ -0,0 +1,45 @@
|
|||
""" Real labels evaluator for ImageNet
|
||||
Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159
|
||||
Based on Numpy example at https://github.com/google-research/reassessed-imagenet
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class RealLabelsImagenet:
|
||||
|
||||
def __init__(self, filenames, real_json='real.json', topk=(1, 5)):
|
||||
with open(real_json) as real_labels:
|
||||
real_labels = json.load(real_labels)
|
||||
real_labels = {
|
||||
f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels
|
||||
for i, labels in enumerate(real_labels)
|
||||
}
|
||||
self.real_labels = real_labels
|
||||
self.filenames = filenames
|
||||
assert len(self.filenames) == len(self.real_labels)
|
||||
self.topk = topk
|
||||
self.is_correct = {k: [] for k in topk}
|
||||
self.sample_idx = 0
|
||||
|
||||
def add_result(self, output):
|
||||
maxk = max(self.topk)
|
||||
_, pred_batch = output.topk(maxk, 1, True, True)
|
||||
pred_batch = pred_batch.cpu().numpy()
|
||||
for pred in pred_batch:
|
||||
filename = self.filenames[self.sample_idx]
|
||||
filename = os.path.basename(filename)
|
||||
if self.real_labels[filename]:
|
||||
for k in self.topk:
|
||||
self.is_correct[k].append(
|
||||
any([p in self.real_labels[filename] for p in pred[:k]]))
|
||||
self.sample_idx += 1
|
||||
|
||||
def get_accuracy(self, k=None):
|
||||
if k is None:
|
||||
return {k: float(np.mean(self.is_correct[k] for k in self.topk))}
|
||||
else:
|
||||
return float(np.mean(self.is_correct[k])) * 100
|
|
@ -0,0 +1 @@
|
|||
from .ra_sampler import RASampler
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) 2015-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the CC-by-NC license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import math
|
||||
|
||||
|
||||
class RASampler(torch.utils.data.Sampler):
|
||||
"""Sampler that restricts data loading to a subset of the dataset for distributed,
|
||||
with repeated augmentation.
|
||||
It ensures that different each augmented version of a sample will be visible to a
|
||||
different process (GPU)
|
||||
Heavily based on torch.utils.data.DistributedSampler
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = dist.get_world_size()
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = dist.get_rank()
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
# self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
|
||||
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
|
||||
self.shuffle = shuffle
|
||||
|
||||
def __iter__(self):
|
||||
# deterministically shuffle based on epoch
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
if self.shuffle:
|
||||
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
||||
else:
|
||||
indices = list(range(len(self.dataset)))
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
indices = [ele for ele in indices for i in range(3)]
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices[:self.num_selected_samples])
|
||||
|
||||
def __len__(self):
|
||||
return self.num_selected_samples
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
|
@ -0,0 +1 @@
|
|||
from .build import build_transforms
|
|
@ -0,0 +1,124 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from timm.data import create_transform
|
||||
from PIL import ImageFilter
|
||||
import logging
|
||||
import random
|
||||
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
class GaussianBlur(object):
|
||||
"""Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
|
||||
|
||||
def __init__(self, sigma=[.1, 2.]):
|
||||
self.sigma = sigma
|
||||
|
||||
def __call__(self, x):
|
||||
sigma = random.uniform(self.sigma[0], self.sigma[1])
|
||||
x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
|
||||
return x
|
||||
|
||||
|
||||
def get_resolution(original_resolution):
|
||||
"""Takes (H,W) and returns (precrop, crop)."""
|
||||
area = original_resolution[0] * original_resolution[1]
|
||||
return (160, 128) if area < 96*96 else (512, 480)
|
||||
|
||||
|
||||
def build_transforms(cfg, is_train=True):
|
||||
if cfg.AUG.TIMM_AUG.USE_TRANSFORM and is_train:
|
||||
logging.info('=> use timm transform for training')
|
||||
timm_cfg = cfg.AUG.TIMM_AUG
|
||||
transforms = create_transform(
|
||||
input_size=cfg.TRAIN.IMAGE_SIZE[0],
|
||||
is_training=True,
|
||||
use_prefetcher=False,
|
||||
no_aug=False,
|
||||
re_prob=timm_cfg.RE_PROB,
|
||||
re_mode=timm_cfg.RE_MODE,
|
||||
re_count=timm_cfg.RE_COUNT,
|
||||
scale=cfg.AUG.SCALE,
|
||||
ratio=cfg.AUG.RATIO,
|
||||
hflip=timm_cfg.HFLIP,
|
||||
vflip=timm_cfg.VFLIP,
|
||||
color_jitter=timm_cfg.COLOR_JITTER,
|
||||
auto_augment=timm_cfg.AUTO_AUGMENT,
|
||||
interpolation=timm_cfg.INTERPOLATION,
|
||||
mean=cfg.INPUT.MEAN,
|
||||
std=cfg.INPUT.STD,
|
||||
)
|
||||
|
||||
return transforms
|
||||
|
||||
# assert isinstance(cfg.DATASET.OUTPUT_SIZE, (list, tuple)), 'DATASET.OUTPUT_SIZE should be list or tuple'
|
||||
normalize = T.Normalize(mean=cfg.INPUT.MEAN, std=cfg.INPUT.STD)
|
||||
|
||||
transforms = None
|
||||
if is_train:
|
||||
if cfg.FINETUNE.FINETUNE and not cfg.FINETUNE.USE_TRAIN_AUG:
|
||||
# precrop, crop = get_resolution(cfg.TRAIN.IMAGE_SIZE)
|
||||
crop = cfg.TRAIN.IMAGE_SIZE[0]
|
||||
precrop = crop + 32
|
||||
transforms = T.Compose([
|
||||
T.Resize(
|
||||
(precrop, precrop),
|
||||
interpolation=cfg.AUG.INTERPOLATION
|
||||
),
|
||||
T.RandomCrop((crop, crop)),
|
||||
T.RandomHorizontalFlip(),
|
||||
T.ToTensor(),
|
||||
normalize,
|
||||
])
|
||||
else:
|
||||
aug = cfg.AUG
|
||||
scale = aug.SCALE
|
||||
ratio = aug.RATIO
|
||||
ts = [
|
||||
T.RandomResizedCrop(
|
||||
cfg.TRAIN.IMAGE_SIZE[0], scale=scale, ratio=ratio,
|
||||
interpolation=cfg.AUG.INTERPOLATION
|
||||
),
|
||||
T.RandomHorizontalFlip(),
|
||||
]
|
||||
|
||||
cj = aug.COLOR_JITTER
|
||||
if cj[-1] > 0.0:
|
||||
ts.append(T.RandomApply([T.ColorJitter(*cj[:-1])], p=cj[-1]))
|
||||
|
||||
gs = aug.GRAY_SCALE
|
||||
if gs > 0.0:
|
||||
ts.append(T.RandomGrayscale(gs))
|
||||
|
||||
gb = aug.GAUSSIAN_BLUR
|
||||
if gb > 0.0:
|
||||
ts.append(T.RandomApply([GaussianBlur([.1, 2.])], p=gb))
|
||||
|
||||
ts.append(T.ToTensor())
|
||||
ts.append(normalize)
|
||||
|
||||
transforms = T.Compose(ts)
|
||||
else:
|
||||
if cfg.TEST.CENTER_CROP:
|
||||
transforms = T.Compose([
|
||||
T.Resize(
|
||||
int(cfg.TEST.IMAGE_SIZE[0] / 0.875),
|
||||
interpolation=cfg.TEST.INTERPOLATION
|
||||
),
|
||||
T.CenterCrop(cfg.TEST.IMAGE_SIZE[0]),
|
||||
T.ToTensor(),
|
||||
normalize,
|
||||
])
|
||||
else:
|
||||
transforms = T.Compose([
|
||||
T.Resize(
|
||||
(cfg.TEST.IMAGE_SIZE[1], cfg.TEST.IMAGE_SIZE[0]),
|
||||
interpolation=cfg.TEST.INTERPOLATION
|
||||
),
|
||||
T.ToTensor(),
|
||||
normalize,
|
||||
])
|
||||
|
||||
return transforms
|
|
@ -0,0 +1,9 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from .cls_cvt import *
|
||||
|
||||
from .registry import *
|
||||
|
||||
from .build import build_model
|
|
@ -0,0 +1,10 @@
|
|||
from .registry import model_entrypoints
|
||||
from .registry import is_model
|
||||
|
||||
|
||||
def build_model(config, **kwargs):
|
||||
model_name = config.MODEL.NAME
|
||||
if not is_model(model_name):
|
||||
raise ValueError(f'Unkown model: {model_name}')
|
||||
|
||||
return model_entrypoints(model_name)(config, **kwargs)
|
|
@ -0,0 +1,645 @@
|
|||
from functools import partial
|
||||
from itertools import repeat
|
||||
from torch._six import container_abcs
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
from timm.models.layers import DropPath, trunc_normal_
|
||||
|
||||
from .registry import register_model
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, container_abcs.Iterable):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
to_1tuple = _ntuple(1)
|
||||
to_2tuple = _ntuple(2)
|
||||
to_3tuple = _ntuple(3)
|
||||
to_4tuple = _ntuple(4)
|
||||
to_ntuple = _ntuple
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self,
|
||||
dim_in,
|
||||
dim_out,
|
||||
num_heads,
|
||||
qkv_bias=False,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
method='dw_bn',
|
||||
kernel_size=3,
|
||||
stride_kv=1,
|
||||
stride_q=1,
|
||||
padding_kv=1,
|
||||
padding_q=1,
|
||||
with_cls_token=True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.stride_kv = stride_kv
|
||||
self.stride_q = stride_q
|
||||
self.dim = dim_out
|
||||
self.num_heads = num_heads
|
||||
# head_dim = self.qkv_dim // num_heads
|
||||
self.scale = dim_out ** -0.5
|
||||
self.with_cls_token = with_cls_token
|
||||
|
||||
self.conv_proj_q = self._build_projection(
|
||||
dim_in, dim_out, kernel_size, padding_q,
|
||||
stride_q, 'linear' if method == 'avg' else method
|
||||
)
|
||||
self.conv_proj_k = self._build_projection(
|
||||
dim_in, dim_out, kernel_size, padding_kv,
|
||||
stride_kv, method
|
||||
)
|
||||
self.conv_proj_v = self._build_projection(
|
||||
dim_in, dim_out, kernel_size, padding_kv,
|
||||
stride_kv, method
|
||||
)
|
||||
|
||||
self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
|
||||
self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
|
||||
self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim_out, dim_out)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def _build_projection(self,
|
||||
dim_in,
|
||||
dim_out,
|
||||
kernel_size,
|
||||
padding,
|
||||
stride,
|
||||
method):
|
||||
if method == 'dw_bn':
|
||||
proj = nn.Sequential(OrderedDict([
|
||||
('conv', nn.Conv2d(
|
||||
dim_in,
|
||||
dim_in,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
stride=stride,
|
||||
bias=False,
|
||||
groups=dim_in
|
||||
)),
|
||||
('bn', nn.BatchNorm2d(dim_in)),
|
||||
('rearrage', Rearrange('b c h w -> b (h w) c')),
|
||||
]))
|
||||
elif method == 'avg':
|
||||
proj = nn.Sequential(OrderedDict([
|
||||
('avg', nn.AvgPool2d(
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
stride=stride,
|
||||
ceil_mode=True
|
||||
)),
|
||||
('rearrage', Rearrange('b c h w -> b (h w) c')),
|
||||
]))
|
||||
elif method == 'linear':
|
||||
proj = None
|
||||
else:
|
||||
raise ValueError('Unknown method ({})'.format(method))
|
||||
|
||||
return proj
|
||||
|
||||
def forward_conv(self, x, h, w):
|
||||
if self.with_cls_token:
|
||||
cls_token, x = torch.split(x, [1, h*w], 1)
|
||||
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
||||
|
||||
if self.conv_proj_q is not None:
|
||||
q = self.conv_proj_q(x)
|
||||
else:
|
||||
q = rearrange(x, 'b c h w -> b (h w) c')
|
||||
|
||||
if self.conv_proj_k is not None:
|
||||
k = self.conv_proj_k(x)
|
||||
else:
|
||||
k = rearrange(x, 'b c h w -> b (h w) c')
|
||||
|
||||
if self.conv_proj_v is not None:
|
||||
v = self.conv_proj_v(x)
|
||||
else:
|
||||
v = rearrange(x, 'b c h w -> b (h w) c')
|
||||
|
||||
if self.with_cls_token:
|
||||
q = torch.cat((cls_token, q), dim=1)
|
||||
k = torch.cat((cls_token, k), dim=1)
|
||||
v = torch.cat((cls_token, v), dim=1)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def forward(self, x, h, w):
|
||||
if (
|
||||
self.conv_proj_q is not None
|
||||
or self.conv_proj_k is not None
|
||||
or self.conv_proj_v is not None
|
||||
):
|
||||
q, k, v = self.forward_conv(x, h, w)
|
||||
|
||||
q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
|
||||
k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
|
||||
v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)
|
||||
|
||||
attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
|
||||
attn = F.softmax(attn_score, dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
|
||||
x = rearrange(x, 'b h t d -> b t (h d)')
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def compute_macs(module, input, output):
|
||||
# T: num_token
|
||||
# S: num_token
|
||||
input = input[0]
|
||||
flops = 0
|
||||
|
||||
_, T, C = input.shape
|
||||
H = W = int(np.sqrt(T-1)) if module.with_cls_token else int(np.sqrt(T))
|
||||
|
||||
H_Q = H / module.stride_q
|
||||
W_Q = H / module.stride_q
|
||||
T_Q = H_Q * W_Q + 1 if module.with_cls_token else H_Q * W_Q
|
||||
|
||||
H_KV = H / module.stride_kv
|
||||
W_KV = W / module.stride_kv
|
||||
T_KV = H_KV * W_KV + 1 if module.with_cls_token else H_KV * W_KV
|
||||
|
||||
# C = module.dim
|
||||
# S = T
|
||||
# Scaled-dot-product macs
|
||||
# [B x T x C] x [B x C x T] --> [B x T x S]
|
||||
# multiplication-addition is counted as 1 because operations can be fused
|
||||
flops += T_Q * T_KV * module.dim
|
||||
# [B x T x S] x [B x S x C] --> [B x T x C]
|
||||
flops += T_Q * module.dim * T_KV
|
||||
|
||||
if (
|
||||
hasattr(module, 'conv_proj_q')
|
||||
and hasattr(module.conv_proj_q, 'conv')
|
||||
):
|
||||
params = sum(
|
||||
[
|
||||
p.numel()
|
||||
for p in module.conv_proj_q.conv.parameters()
|
||||
]
|
||||
)
|
||||
flops += params * H_Q * W_Q
|
||||
|
||||
if (
|
||||
hasattr(module, 'conv_proj_k')
|
||||
and hasattr(module.conv_proj_k, 'conv')
|
||||
):
|
||||
params = sum(
|
||||
[
|
||||
p.numel()
|
||||
for p in module.conv_proj_k.conv.parameters()
|
||||
]
|
||||
)
|
||||
flops += params * H_KV * W_KV
|
||||
|
||||
if (
|
||||
hasattr(module, 'conv_proj_v')
|
||||
and hasattr(module.conv_proj_v, 'conv')
|
||||
):
|
||||
params = sum(
|
||||
[
|
||||
p.numel()
|
||||
for p in module.conv_proj_v.conv.parameters()
|
||||
]
|
||||
)
|
||||
flops += params * H_KV * W_KV
|
||||
|
||||
params = sum([p.numel() for p in module.proj_q.parameters()])
|
||||
flops += params * T_Q
|
||||
params = sum([p.numel() for p in module.proj_k.parameters()])
|
||||
flops += params * T_KV
|
||||
params = sum([p.numel() for p in module.proj_v.parameters()])
|
||||
flops += params * T_KV
|
||||
params = sum([p.numel() for p in module.proj.parameters()])
|
||||
flops += params * T
|
||||
|
||||
module.__flops__ += flops
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim_in,
|
||||
dim_out,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.with_cls_token = kwargs['with_cls_token']
|
||||
|
||||
self.norm1 = norm_layer(dim_in)
|
||||
self.attn = Attention(
|
||||
dim_in, dim_out, num_heads, qkv_bias, attn_drop, drop,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.drop_path = DropPath(drop_path) \
|
||||
if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim_out)
|
||||
|
||||
dim_mlp_hidden = int(dim_out * mlp_ratio)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim_out,
|
||||
hidden_features=dim_mlp_hidden,
|
||||
act_layer=act_layer,
|
||||
drop=drop
|
||||
)
|
||||
|
||||
def forward(self, x, h, w):
|
||||
res = x
|
||||
|
||||
x = self.norm1(x)
|
||||
attn = self.attn(x, h, w)
|
||||
x = res + self.drop_path(attn)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ConvEmbed(nn.Module):
|
||||
""" Image to Conv Embedding
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
patch_size=7,
|
||||
in_chans=3,
|
||||
embed_dim=64,
|
||||
stride=4,
|
||||
padding=2,
|
||||
norm_layer=None):
|
||||
super().__init__()
|
||||
patch_size = to_2tuple(patch_size)
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans, embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=stride,
|
||||
padding=padding
|
||||
)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else None
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
|
||||
B, C, H, W = x.shape
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
if self.norm:
|
||||
x = self.norm(x)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
"""
|
||||
def __init__(self,
|
||||
patch_size=16,
|
||||
patch_stride=16,
|
||||
patch_padding=0,
|
||||
in_chans=3,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
init='trunc_norm',
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
|
||||
self.rearrage = None
|
||||
|
||||
self.patch_embed = ConvEmbed(
|
||||
# img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
stride=patch_stride,
|
||||
padding=patch_padding,
|
||||
embed_dim=embed_dim,
|
||||
norm_layer=norm_layer
|
||||
)
|
||||
|
||||
with_cls_token = kwargs['with_cls_token']
|
||||
if with_cls_token:
|
||||
self.cls_token = nn.Parameter(
|
||||
torch.zeros(1, 1, embed_dim)
|
||||
)
|
||||
else:
|
||||
self.cls_token = None
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
|
||||
blocks = []
|
||||
for j in range(depth):
|
||||
blocks.append(
|
||||
Block(
|
||||
dim_in=embed_dim,
|
||||
dim_out=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[j],
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
**kwargs
|
||||
)
|
||||
)
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
if self.cls_token is not None:
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
|
||||
if init == 'xavier':
|
||||
self.apply(self._init_weights_xavier)
|
||||
else:
|
||||
self.apply(self._init_weights_trunc_normal)
|
||||
|
||||
def _init_weights_trunc_normal(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
logging.info('=> init weight of Linear from trunc norm')
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
logging.info('=> init bias of Linear to zeros')
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def _init_weights_xavier(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
logging.info('=> init weight of Linear from xavier uniform')
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if m.bias is not None:
|
||||
logging.info('=> init bias of Linear to zeros')
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
B, C, H, W = x.size()
|
||||
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
|
||||
cls_tokens = None
|
||||
if self.cls_token is not None:
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for i, blk in enumerate(self.blocks):
|
||||
x = blk(x, H, W)
|
||||
|
||||
if self.cls_token is not None:
|
||||
cls_tokens, x = torch.split(x, [1, H*W], 1)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
|
||||
|
||||
return x, cls_tokens
|
||||
|
||||
|
||||
class ConvolutionalVisionTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
init='trunc_norm',
|
||||
spec=None):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
|
||||
self.num_stages = spec['NUM_STAGES']
|
||||
for i in range(self.num_stages):
|
||||
kwargs = {
|
||||
'patch_size': spec['PATCH_SIZE'][i],
|
||||
'patch_stride': spec['PATCH_STRIDE'][i],
|
||||
'patch_padding': spec['PATCH_PADDING'][i],
|
||||
'embed_dim': spec['DIM_EMBED'][i],
|
||||
'depth': spec['DEPTH'][i],
|
||||
'num_heads': spec['NUM_HEADS'][i],
|
||||
'mlp_ratio': spec['MLP_RATIO'][i],
|
||||
'qkv_bias': spec['QKV_BIAS'][i],
|
||||
'drop_rate': spec['DROP_RATE'][i],
|
||||
'attn_drop_rate': spec['ATTN_DROP_RATE'][i],
|
||||
'drop_path_rate': spec['DROP_PATH_RATE'][i],
|
||||
'with_cls_token': spec['CLS_TOKEN'][i],
|
||||
'method': spec['QKV_PROJ_METHOD'][i],
|
||||
'kernel_size': spec['KERNEL_QKV'][i],
|
||||
'padding_q': spec['PADDING_Q'][i],
|
||||
'padding_kv': spec['PADDING_KV'][i],
|
||||
'stride_kv': spec['STRIDE_KV'][i],
|
||||
'stride_q': spec['STRIDE_Q'][i],
|
||||
}
|
||||
|
||||
stage = VisionTransformer(
|
||||
in_chans=in_chans,
|
||||
init=init,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
**kwargs
|
||||
)
|
||||
setattr(self, f'stage{i}', stage)
|
||||
|
||||
in_chans = spec['DIM_EMBED'][i]
|
||||
|
||||
dim_embed = spec['DIM_EMBED'][-1]
|
||||
self.norm = norm_layer(dim_embed)
|
||||
self.cls_token = spec['CLS_TOKEN'][-1]
|
||||
|
||||
# Classifier head
|
||||
self.head = nn.Linear(dim_embed, num_classes) if num_classes > 0 else nn.Identity()
|
||||
trunc_normal_(self.head.weight, std=0.02)
|
||||
|
||||
def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
|
||||
if os.path.isfile(pretrained):
|
||||
pretrained_dict = torch.load(pretrained, map_location='cpu')
|
||||
logging.info(f'=> loading pretrained model {pretrained}')
|
||||
model_dict = self.state_dict()
|
||||
pretrained_dict = {
|
||||
k: v for k, v in pretrained_dict.items()
|
||||
if k in model_dict.keys()
|
||||
}
|
||||
need_init_state_dict = {}
|
||||
for k, v in pretrained_dict.items():
|
||||
need_init = (
|
||||
k.split('.')[0] in pretrained_layers
|
||||
or pretrained_layers[0] is '*'
|
||||
)
|
||||
if need_init:
|
||||
if verbose:
|
||||
logging.info(f'=> init {k} from {pretrained}')
|
||||
if 'pos_embed' in k and v.size() != model_dict[k].size():
|
||||
size_pretrained = v.size()
|
||||
size_new = model_dict[k].size()
|
||||
logging.info(
|
||||
'=> load_pretrained: resized variant: {} to {}'
|
||||
.format(size_pretrained, size_new)
|
||||
)
|
||||
|
||||
ntok_new = size_new[1]
|
||||
ntok_new -= 1
|
||||
|
||||
posemb_tok, posemb_grid = v[:, :1], v[0, 1:]
|
||||
|
||||
gs_old = int(np.sqrt(len(posemb_grid)))
|
||||
gs_new = int(np.sqrt(ntok_new))
|
||||
|
||||
logging.info(
|
||||
'=> load_pretrained: grid-size from {} to {}'
|
||||
.format(gs_old, gs_new)
|
||||
)
|
||||
|
||||
posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
|
||||
zoom = (gs_new / gs_old, gs_new / gs_old, 1)
|
||||
posemb_grid = scipy.ndimage.zoom(
|
||||
posemb_grid, zoom, order=1
|
||||
)
|
||||
posemb_grid = posemb_grid.reshape(1, gs_new ** 2, -1)
|
||||
v = torch.tensor(
|
||||
np.concatenate([posemb_tok, posemb_grid], axis=1)
|
||||
)
|
||||
|
||||
need_init_state_dict[k] = v
|
||||
self.load_state_dict(need_init_state_dict, strict=False)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
layers = set()
|
||||
for i in range(self.num_stages):
|
||||
layers.add(f'stage{i}.pos_embed')
|
||||
layers.add(f'stage{i}.cls_token')
|
||||
|
||||
return layers
|
||||
|
||||
def forward_features(self, x):
|
||||
for i in range(self.num_stages):
|
||||
x, cls_tokens = getattr(self, f'stage{i}')(x)
|
||||
|
||||
if self.cls_token:
|
||||
x = self.norm(cls_tokens)
|
||||
x = torch.squeeze(x)
|
||||
else:
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
x = self.norm(x)
|
||||
x = torch.mean(x, dim=1)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.head(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@register_model
|
||||
def get_cls_model(config, **kwargs):
|
||||
msvit_spec = config.MODEL.SPEC
|
||||
msvit = ConvolutionalVisionTransformer(
|
||||
in_chans=3,
|
||||
num_classes=config.MODEL.NUM_CLASSES,
|
||||
act_layer=QuickGELU,
|
||||
norm_layer=partial(LayerNorm, eps=1e-5),
|
||||
init=getattr(msvit_spec, 'INIT', 'trunc_norm'),
|
||||
spec=msvit_spec
|
||||
)
|
||||
|
||||
if config.MODEL.INIT_WEIGHTS:
|
||||
msvit.init_weights(
|
||||
config.MODEL.PRETRAINED,
|
||||
config.MODEL.PRETRAINED_LAYERS,
|
||||
config.VERBOSE
|
||||
)
|
||||
|
||||
return msvit
|
|
@ -0,0 +1,18 @@
|
|||
_model_entrypoints = {}
|
||||
|
||||
|
||||
def register_model(fn):
|
||||
module_name_split = fn.__module__.split('.')
|
||||
model_name = module_name_split[-1]
|
||||
|
||||
_model_entrypoints[model_name] = fn
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def model_entrypoints(model_name):
|
||||
return _model_entrypoints[model_name]
|
||||
|
||||
|
||||
def is_model(model_name):
|
||||
return model_name in _model_entrypoints
|
|
@ -0,0 +1 @@
|
|||
from .build import build_optimizer
|
|
@ -0,0 +1,155 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from timm.optim import create_optimizer
|
||||
|
||||
|
||||
def _is_depthwise(m):
|
||||
return (
|
||||
isinstance(m, nn.Conv2d)
|
||||
and m.groups == m.in_channels
|
||||
and m.groups == m.out_channels
|
||||
)
|
||||
|
||||
|
||||
def set_wd(cfg, model):
|
||||
without_decay_list = cfg.TRAIN.WITHOUT_WD_LIST
|
||||
without_decay_depthwise = []
|
||||
without_decay_norm = []
|
||||
for m in model.modules():
|
||||
if _is_depthwise(m) and 'dw' in without_decay_list:
|
||||
without_decay_depthwise.append(m.weight)
|
||||
elif isinstance(m, nn.BatchNorm2d) and 'bn' in without_decay_list:
|
||||
without_decay_norm.append(m.weight)
|
||||
without_decay_norm.append(m.bias)
|
||||
elif isinstance(m, nn.GroupNorm) and 'gn' in without_decay_list:
|
||||
without_decay_norm.append(m.weight)
|
||||
without_decay_norm.append(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm) and 'ln' in without_decay_list:
|
||||
without_decay_norm.append(m.weight)
|
||||
without_decay_norm.append(m.bias)
|
||||
|
||||
with_decay = []
|
||||
without_decay = []
|
||||
|
||||
skip = {}
|
||||
if hasattr(model, 'no_weight_decay'):
|
||||
skip = model.no_weight_decay()
|
||||
|
||||
skip_keys = {}
|
||||
if hasattr(model, 'no_weight_decay_keywords'):
|
||||
skip_keys = model.no_weight_decay_keywords()
|
||||
|
||||
for n, p in model.named_parameters():
|
||||
ever_set = False
|
||||
|
||||
if p.requires_grad is False:
|
||||
continue
|
||||
|
||||
skip_flag = False
|
||||
if n in skip:
|
||||
print('=> set {} wd to 0'.format(n))
|
||||
without_decay.append(p)
|
||||
skip_flag = True
|
||||
else:
|
||||
for i in skip:
|
||||
if i in n:
|
||||
print('=> set {} wd to 0'.format(n))
|
||||
without_decay.append(p)
|
||||
skip_flag = True
|
||||
|
||||
if skip_flag:
|
||||
continue
|
||||
|
||||
for i in skip_keys:
|
||||
if i in n:
|
||||
print('=> set {} wd to 0'.format(n))
|
||||
|
||||
if skip_flag:
|
||||
continue
|
||||
|
||||
for pp in without_decay_depthwise:
|
||||
if p is pp:
|
||||
if cfg.VERBOSE:
|
||||
print('=> set depthwise({}) wd to 0'.format(n))
|
||||
without_decay.append(p)
|
||||
ever_set = True
|
||||
break
|
||||
|
||||
for pp in without_decay_norm:
|
||||
if p is pp:
|
||||
if cfg.VERBOSE:
|
||||
print('=> set norm({}) wd to 0'.format(n))
|
||||
without_decay.append(p)
|
||||
ever_set = True
|
||||
break
|
||||
|
||||
if (
|
||||
(not ever_set)
|
||||
and 'bias' in without_decay_list
|
||||
and n.endswith('.bias')
|
||||
):
|
||||
if cfg.VERBOSE:
|
||||
print('=> set bias({}) wd to 0'.format(n))
|
||||
without_decay.append(p)
|
||||
elif not ever_set:
|
||||
with_decay.append(p)
|
||||
|
||||
# assert (len(with_decay) + len(without_decay) == len(list(model.parameters())))
|
||||
params = [
|
||||
{'params': with_decay},
|
||||
{'params': without_decay, 'weight_decay': 0.}
|
||||
]
|
||||
return params
|
||||
|
||||
|
||||
def build_optimizer(cfg, model):
|
||||
if cfg.TRAIN.OPTIMIZER == 'timm':
|
||||
args = cfg.TRAIN.OPTIMIZER_ARGS
|
||||
|
||||
print(f'=> usage timm optimizer args: {cfg.TRAIN.OPTIMIZER_ARGS}')
|
||||
optimizer = create_optimizer(args, model)
|
||||
|
||||
return optimizer
|
||||
|
||||
optimizer = None
|
||||
params = set_wd(cfg, model)
|
||||
if cfg.TRAIN.OPTIMIZER == 'sgd':
|
||||
optimizer = optim.SGD(
|
||||
params,
|
||||
# filter(lambda p: p.requires_grad, model.parameters()),
|
||||
lr=cfg.TRAIN.LR,
|
||||
momentum=cfg.TRAIN.MOMENTUM,
|
||||
weight_decay=cfg.TRAIN.WD,
|
||||
nesterov=cfg.TRAIN.NESTEROV
|
||||
)
|
||||
elif cfg.TRAIN.OPTIMIZER == 'adam':
|
||||
optimizer = optim.Adam(
|
||||
params,
|
||||
# filter(lambda p: p.requires_grad, model.parameters()),
|
||||
lr=cfg.TRAIN.LR,
|
||||
weight_decay=cfg.TRAIN.WD,
|
||||
)
|
||||
elif cfg.TRAIN.OPTIMIZER == 'adamW':
|
||||
optimizer = optim.AdamW(
|
||||
params,
|
||||
lr=cfg.TRAIN.LR,
|
||||
weight_decay=cfg.TRAIN.WD,
|
||||
)
|
||||
elif cfg.TRAIN.OPTIMIZER == 'rmsprop':
|
||||
optimizer = optim.RMSprop(
|
||||
params,
|
||||
# filter(lambda p: p.requires_grad, model.parameters()),
|
||||
lr=cfg.TRAIN.LR,
|
||||
momentum=cfg.TRAIN.MOMENTUM,
|
||||
weight_decay=cfg.TRAIN.WD,
|
||||
alpha=cfg.TRAIN.RMSPROP_ALPHA,
|
||||
centered=cfg.TRAIN.RMSPROP_CENTERED
|
||||
)
|
||||
|
||||
return optimizer
|
||||
|
|
@ -0,0 +1 @@
|
|||
from .build import build_lr_scheduler
|
|
@ -0,0 +1,41 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
from timm.scheduler import create_scheduler
|
||||
|
||||
|
||||
def build_lr_scheduler(cfg, optimizer, begin_epoch):
|
||||
if 'METHOD' not in cfg.TRAIN.LR_SCHEDULER:
|
||||
raise ValueError('Please set TRAIN.LR_SCHEDULER.METHOD!')
|
||||
elif cfg.TRAIN.LR_SCHEDULER.METHOD == 'MultiStep':
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer,
|
||||
cfg.TRAIN.LR_SCHEDULER.MILESTONES,
|
||||
cfg.TRAIN.LR_SCHEDULER.GAMMA,
|
||||
begin_epoch - 1)
|
||||
elif cfg.TRAIN.LR_SCHEDULER.METHOD == 'CosineAnnealing':
|
||||
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer,
|
||||
cfg.TRAIN.END_EPOCH,
|
||||
cfg.TRAIN.LR_SCHEDULER.ETA_MIN,
|
||||
begin_epoch - 1
|
||||
)
|
||||
elif cfg.TRAIN.LR_SCHEDULER.METHOD == 'CyclicLR':
|
||||
lr_scheduler = torch.optim.lr_scheduler.CyclicLR(
|
||||
optimizer,
|
||||
base_lr=cfg.TRAIN.LR_SCHEDULER.BASE_LR,
|
||||
max_LR=cfg.TRAIN.LR_SCHEDULER.MAX_LR,
|
||||
step_size_up=cfg.TRAIN.LR_SCHEDULER.STEP_SIZE_UP
|
||||
)
|
||||
elif cfg.TRAIN.LR_SCHEDULER.METHOD == 'timm':
|
||||
args = cfg.TRAIN.LR_SCHEDULER.ARGS
|
||||
lr_scheduler, _ = create_scheduler(args, optimizer)
|
||||
lr_scheduler.step(begin_epoch)
|
||||
else:
|
||||
raise ValueError('Unknown lr scheduler: {}'.format(
|
||||
cfg.TRAIN.LR_SCHEDULER.METHOD))
|
||||
|
||||
return lr_scheduler
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
import pickle
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class Comm(object):
|
||||
def __init__(self, local_rank=0):
|
||||
self.local_rank = 0
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
if not dist.is_available():
|
||||
return 1
|
||||
if not dist.is_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
if not dist.is_available():
|
||||
return 0
|
||||
if not dist.is_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
@property
|
||||
def local_rank(self):
|
||||
if not dist.is_available():
|
||||
return 0
|
||||
if not dist.is_initialized():
|
||||
return 0
|
||||
return self._local_rank
|
||||
|
||||
@local_rank.setter
|
||||
def local_rank(self, value):
|
||||
if not dist.is_available():
|
||||
self._local_rank = 0
|
||||
if not dist.is_initialized():
|
||||
self._local_rank = 0
|
||||
self._local_rank = value
|
||||
|
||||
@property
|
||||
def head(self):
|
||||
return 'Rank[{}/{}]'.format(self.rank, self.world_size)
|
||||
|
||||
def is_main_process(self):
|
||||
return self.rank == 0
|
||||
|
||||
def synchronize(self):
|
||||
"""
|
||||
Helper function to synchronize (barrier) among all processes when
|
||||
using distributed training
|
||||
"""
|
||||
if self.world_size == 1:
|
||||
return
|
||||
dist.barrier()
|
||||
|
||||
|
||||
comm = Comm()
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = comm.world_size
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to("cuda")
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.LongTensor([tensor.numel()]).to("cuda")
|
||||
size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
|
||||
if local_size != max_size:
|
||||
padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Args:
|
||||
input_dict (dict): all the values will be reduced
|
||||
average (bool): whether to do average or sum
|
||||
Reduce the values in the dictionary from all processes so that process with rank
|
||||
0 has the averaged results. Returns a dict with the same fields as
|
||||
input_dict, after reduction.
|
||||
"""
|
||||
world_size = comm.world_size
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.reduce(values, dst=0)
|
||||
if dist.get_rank() == 0 and average:
|
||||
# only main process gets accumulated, so only divide by
|
||||
# world_size in this case
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||
return reduced_dict
|
||||
|
|
@ -0,0 +1,213 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
import time
|
||||
|
||||
import tensorwatch as tw
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
from utils.comm import comm
|
||||
from ptflops import get_model_complexity_info
|
||||
|
||||
|
||||
def setup_logger(final_output_dir, rank, phase):
|
||||
time_str = time.strftime('%Y-%m-%d-%H-%M')
|
||||
log_file = '{}_{}_rank{}.txt'.format(phase, time_str, rank)
|
||||
final_log_file = os.path.join(final_output_dir, log_file)
|
||||
head = '%(asctime)-15s:[P:%(process)d]:' + comm.head + ' %(message)s'
|
||||
logging.basicConfig(
|
||||
filename=str(final_log_file), format=head
|
||||
)
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
console = logging.StreamHandler()
|
||||
console.setFormatter(
|
||||
logging.Formatter(head)
|
||||
)
|
||||
logging.getLogger('').addHandler(console)
|
||||
|
||||
|
||||
def create_logger(cfg, cfg_name, phase='train'):
|
||||
root_output_dir = Path(cfg.OUTPUT_DIR)
|
||||
dataset = cfg.DATASET.DATASET
|
||||
cfg_name = cfg.NAME
|
||||
|
||||
final_output_dir = root_output_dir / dataset / cfg_name
|
||||
|
||||
print('=> creating {} ...'.format(root_output_dir))
|
||||
root_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
print('=> creating {} ...'.format(final_output_dir))
|
||||
final_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print('=> setup logger ...')
|
||||
setup_logger(final_output_dir, cfg.RANK, phase)
|
||||
|
||||
return str(final_output_dir)
|
||||
|
||||
|
||||
def init_distributed(args):
|
||||
args.num_gpus = int(os.environ["WORLD_SIZE"]) \
|
||||
if "WORLD_SIZE" in os.environ else 1
|
||||
args.distributed = args.num_gpus > 1
|
||||
|
||||
if args.distributed:
|
||||
print("=> init process group start")
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl", init_method="env://",
|
||||
timeout=timedelta(minutes=180))
|
||||
comm.local_rank = args.local_rank
|
||||
print("=> init process group end")
|
||||
|
||||
|
||||
def setup_cudnn(config):
|
||||
cudnn.benchmark = config.CUDNN.BENCHMARK
|
||||
torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
|
||||
torch.backends.cudnn.enabled = config.CUDNN.ENABLED
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
return params/1000000
|
||||
|
||||
|
||||
def summary_model_on_master(model, config, output_dir, copy):
|
||||
if comm.is_main_process():
|
||||
this_dir = os.path.dirname(__file__)
|
||||
shutil.copy2(
|
||||
os.path.join(this_dir, '../models', config.MODEL.NAME + '.py'),
|
||||
output_dir
|
||||
)
|
||||
logging.info('=> {}'.format(model))
|
||||
try:
|
||||
num_params = count_parameters(model)
|
||||
logging.info("Trainable Model Total Parameter: \t%2.1fM" % num_params)
|
||||
except Exception:
|
||||
logging.error('=> error when counting parameters')
|
||||
|
||||
if config.MODEL_SUMMARY:
|
||||
try:
|
||||
logging.info('== model_stats by tensorwatch ==')
|
||||
df = tw.model_stats(
|
||||
model,
|
||||
(1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
|
||||
)
|
||||
df.to_html(os.path.join(output_dir, 'model_summary.html'))
|
||||
df.to_csv(os.path.join(output_dir, 'model_summary.csv'))
|
||||
msg = '*'*20 + ' Model summary ' + '*'*20
|
||||
logging.info(
|
||||
'\n{msg}\n{summary}\n{msg}'.format(
|
||||
msg=msg, summary=df.iloc[-1]
|
||||
)
|
||||
)
|
||||
logging.info('== model_stats by tensorwatch ==')
|
||||
except Exception:
|
||||
logging.error('=> error when run model_stats')
|
||||
|
||||
try:
|
||||
logging.info('== get_model_complexity_info by ptflops ==')
|
||||
macs, params = get_model_complexity_info(
|
||||
model,
|
||||
(3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]),
|
||||
as_strings=True, print_per_layer_stat=True, verbose=True
|
||||
)
|
||||
logging.info(f'=> FLOPs: {macs:<8}, params: {params:<8}')
|
||||
logging.info('== get_model_complexity_info by ptflops ==')
|
||||
except Exception:
|
||||
logging.error('=> error when run get_model_complexity_info')
|
||||
|
||||
|
||||
def resume_checkpoint(model,
|
||||
optimizer,
|
||||
config,
|
||||
output_dir,
|
||||
in_epoch):
|
||||
best_perf = 0.0
|
||||
begin_epoch_or_step = 0
|
||||
|
||||
checkpoint = os.path.join(output_dir, 'checkpoint.pth')\
|
||||
if not config.TRAIN.CHECKPOINT else config.TRAIN.CHECKPOINT
|
||||
if config.TRAIN.AUTO_RESUME and os.path.exists(checkpoint):
|
||||
logging.info(
|
||||
"=> loading checkpoint '{}'".format(checkpoint)
|
||||
)
|
||||
checkpoint_dict = torch.load(checkpoint, map_location='cpu')
|
||||
best_perf = checkpoint_dict['perf']
|
||||
begin_epoch_or_step = checkpoint_dict['epoch' if in_epoch else 'step']
|
||||
state_dict = checkpoint_dict['state_dict']
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
||||
logging.info(
|
||||
"=> {}: loaded checkpoint '{}' ({}: {})"
|
||||
.format(comm.head,
|
||||
checkpoint,
|
||||
'epoch' if in_epoch else 'step',
|
||||
begin_epoch_or_step)
|
||||
)
|
||||
|
||||
return best_perf, begin_epoch_or_step
|
||||
|
||||
|
||||
def save_checkpoint_on_master(model,
|
||||
*,
|
||||
distributed,
|
||||
model_name,
|
||||
optimizer,
|
||||
output_dir,
|
||||
in_epoch,
|
||||
epoch_or_step,
|
||||
best_perf):
|
||||
if not comm.is_main_process():
|
||||
return
|
||||
|
||||
states = model.module.state_dict() \
|
||||
if distributed else model.state_dict()
|
||||
|
||||
logging.info('=> saving checkpoint to {}'.format(output_dir))
|
||||
save_dict = {
|
||||
'epoch' if in_epoch else 'step': epoch_or_step + 1,
|
||||
'model': model_name,
|
||||
'state_dict': states,
|
||||
'perf': best_perf,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
}
|
||||
|
||||
try:
|
||||
torch.save(save_dict, os.path.join(output_dir, 'checkpoint.pth'))
|
||||
except Exception:
|
||||
logging.error('=> error when saving checkpoint!')
|
||||
|
||||
|
||||
def save_model_on_master(model, distributed, out_dir, fname):
|
||||
if not comm.is_main_process():
|
||||
return
|
||||
|
||||
try:
|
||||
fname_full = os.path.join(out_dir, fname)
|
||||
logging.info(f'=> save model to {fname_full}')
|
||||
torch.save(
|
||||
model.module.state_dict() if distributed else model.state_dict(),
|
||||
fname_full
|
||||
)
|
||||
except Exception:
|
||||
logging.error('=> error when saving checkpoint!')
|
||||
|
||||
|
||||
def strip_prefix_if_present(state_dict, prefix):
|
||||
keys = sorted(state_dict.keys())
|
||||
if not all(key.startswith(prefix) for key in keys):
|
||||
return state_dict
|
||||
from collections import OrderedDict
|
||||
stripped_state_dict = OrderedDict()
|
||||
for key, value in state_dict.items():
|
||||
stripped_state_dict[key.replace(prefix, "")] = value
|
||||
return stripped_state_dict
|
|
@ -0,0 +1,12 @@
|
|||
opencv-python
|
||||
pyyaml
|
||||
json_tricks
|
||||
yacs
|
||||
scikit-learn
|
||||
tensorwatch
|
||||
tensorboardX
|
||||
pandas
|
||||
ptflops
|
||||
timm==0.3.2
|
||||
numpy==1.19.3
|
||||
einops
|
|
@ -0,0 +1,87 @@
|
|||
#!/bin/bash
|
||||
|
||||
train() {
|
||||
python3 -m torch.distributed.launch \
|
||||
--nnodes ${NODE_COUNT} \
|
||||
--node_rank ${RANK} \
|
||||
--master_addr ${MASTER_ADDR} \
|
||||
--master_port ${MASTER_PORT} \
|
||||
--nproc_per_node ${GPUS} \
|
||||
tools/train.py ${EXTRA_ARGS}
|
||||
}
|
||||
|
||||
|
||||
test() {
|
||||
python3 -m torch.distributed.launch \
|
||||
--nnodes ${NODE_COUNT} \
|
||||
--node_rank ${RANK} \
|
||||
--master_addr ${MASTER_ADDR} \
|
||||
--master_port ${MASTER_PORT} \
|
||||
--nproc_per_node ${GPUS} \
|
||||
tools/test.py ${EXTRA_ARGS}
|
||||
}
|
||||
|
||||
############################ Main #############################
|
||||
GPUS=`nvidia-smi -L | wc -l`
|
||||
MASTER_PORT=9000
|
||||
INSTALL_DEPS=false
|
||||
|
||||
while [[ $# -gt 0 ]]
|
||||
do
|
||||
|
||||
key="$1"
|
||||
case $key in
|
||||
-h|--help)
|
||||
echo "Usage: $0 [run_options]"
|
||||
echo "Options:"
|
||||
echo " -g|--gpus <1> - number of gpus to be used"
|
||||
echo " -t|--job-type <train> - job type (train|io|bit_finetune|test)"
|
||||
echo " -p|--port <9000> - master port"
|
||||
echo " -i|--install-deps - If install dependencies (default: False)"
|
||||
exit 1
|
||||
;;
|
||||
-g|--gpus)
|
||||
GPUS=$2
|
||||
shift
|
||||
;;
|
||||
-t|--job-type)
|
||||
JOB_TYPE=$2
|
||||
shift
|
||||
;;
|
||||
-p|--port)
|
||||
MASTER_PORT=$2
|
||||
shift
|
||||
;;
|
||||
-i|--install-deps)
|
||||
INSTALL_DEPS=true
|
||||
;;
|
||||
*)
|
||||
EXTRA_ARGS="$EXTRA_ARGS $1"
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
|
||||
if $INSTALL_DEPS; then
|
||||
python -m pip install -r requirements.txt --user -q
|
||||
fi
|
||||
|
||||
RANK=0
|
||||
MASTER_ADDR=127.0.0.1
|
||||
NODE_COUNT=1
|
||||
echo "job type: ${JOB_TYPE}"
|
||||
echo "rank: ${RANK}"
|
||||
echo "node count: ${NODE_COUNT}"
|
||||
echo "master addr: ${MASTER_ADDR}"
|
||||
|
||||
case $JOB_TYPE in
|
||||
train)
|
||||
train
|
||||
;;
|
||||
test)
|
||||
test
|
||||
;;
|
||||
*)
|
||||
echo "unknown job type"
|
||||
;;
|
||||
esac
|
|
@ -0,0 +1,18 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path as osp
|
||||
import sys
|
||||
|
||||
|
||||
def add_path(path):
|
||||
if path not in sys.path:
|
||||
sys.path.insert(0, path)
|
||||
|
||||
|
||||
this_dir = osp.dirname(__file__)
|
||||
|
||||
lib_path = osp.join(this_dir, '..', 'lib')
|
||||
add_path(lib_path)
|
||||
add_path(osp.join(this_dir, '..'))
|
|
@ -0,0 +1,144 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import pickle as pkl
|
||||
import pprint
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn.parallel
|
||||
import torch.optim
|
||||
from torch.utils.collect_env import get_pretty_env_info
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
import _init_paths
|
||||
from config import config
|
||||
from config import update_config
|
||||
from core.function import test
|
||||
from core.loss import build_criterion
|
||||
from dataset import build_dataloader
|
||||
from dataset import RealLabelsImagenet
|
||||
from models import build_model
|
||||
from utils.comm import comm
|
||||
from utils.utils import create_logger
|
||||
from utils.utils import init_distributed
|
||||
from utils.utils import setup_cudnn
|
||||
from utils.utils import summary_model_on_master
|
||||
from utils.utils import strip_prefix_if_present
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Test classification network')
|
||||
|
||||
parser.add_argument('--cfg',
|
||||
help='experiment configure file name',
|
||||
required=True,
|
||||
type=str)
|
||||
|
||||
# distributed training
|
||||
parser.add_argument("--local_rank", type=int, default=0)
|
||||
parser.add_argument("--port", type=int, default=9000)
|
||||
|
||||
parser.add_argument('opts',
|
||||
help="Modify config options using the command-line",
|
||||
default=None,
|
||||
nargs=argparse.REMAINDER)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
init_distributed(args)
|
||||
setup_cudnn(config)
|
||||
|
||||
update_config(config, args)
|
||||
final_output_dir = create_logger(config, args.cfg, 'test')
|
||||
tb_log_dir = final_output_dir
|
||||
|
||||
if comm.is_main_process():
|
||||
logging.info("=> collecting env info (might take some time)")
|
||||
logging.info("\n" + get_pretty_env_info())
|
||||
logging.info(pprint.pformat(args))
|
||||
logging.info(config)
|
||||
logging.info("=> using {} GPUs".format(args.num_gpus))
|
||||
|
||||
output_config_path = os.path.join(final_output_dir, 'config.yaml')
|
||||
logging.info("=> saving config into: {}".format(output_config_path))
|
||||
|
||||
model = build_model(config)
|
||||
model.to(torch.device('cuda'))
|
||||
|
||||
model_file = config.TEST.MODEL_FILE if config.TEST.MODEL_FILE \
|
||||
else os.path.join(final_output_dir, 'model_best.pth')
|
||||
logging.info('=> load model file: {}'.format(model_file))
|
||||
ext = model_file.split('.')[-1]
|
||||
if ext == 'pth':
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
else:
|
||||
raise ValueError("Unknown model file")
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model.to(torch.device('cuda'))
|
||||
|
||||
writer_dict = {
|
||||
'writer': SummaryWriter(logdir=tb_log_dir),
|
||||
'train_global_steps': 0,
|
||||
'valid_global_steps': 0,
|
||||
}
|
||||
|
||||
summary_model_on_master(model, config, final_output_dir, False)
|
||||
|
||||
if args.distributed:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank
|
||||
)
|
||||
|
||||
# define loss function (criterion) and optimizer
|
||||
criterion = build_criterion(config, train=False)
|
||||
criterion.cuda()
|
||||
|
||||
valid_loader = build_dataloader(config, False, args.distributed)
|
||||
real_labels = None
|
||||
if (
|
||||
config.DATASET.DATASET == 'imagenet'
|
||||
and config.DATASET.DATA_FORMAT == 'tsv'
|
||||
and config.TEST.REAL_LABELS
|
||||
):
|
||||
filenames = valid_loader.dataset.get_filenames()
|
||||
real_json = os.path.join(config.DATASET.ROOT, 'real.json')
|
||||
logging.info('=> loading real labels...')
|
||||
real_labels = RealLabelsImagenet(filenames, real_json)
|
||||
|
||||
valid_labels = None
|
||||
if config.TEST.VALID_LABELS:
|
||||
with open(config.TEST.VALID_LABELS, 'r') as f:
|
||||
valid_labels = {
|
||||
int(line.rstrip()) for line in f
|
||||
}
|
||||
valid_labels = [
|
||||
i in valid_labels for i in range(config.MODEL.NUM_CLASSES)
|
||||
]
|
||||
|
||||
logging.info('=> start testing')
|
||||
start = time.time()
|
||||
test(config, valid_loader, model, criterion,
|
||||
final_output_dir, tb_log_dir, writer_dict,
|
||||
args.distributed, real_labels=real_labels,
|
||||
valid_labels=valid_labels)
|
||||
logging.info('=> test duration time: {:.2f}s'.format(time.time()-start))
|
||||
|
||||
writer_dict['writer'].close()
|
||||
logging.info('=> finish testing')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,212 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import pprint
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn.parallel
|
||||
import torch.optim
|
||||
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
|
||||
from torch.utils.collect_env import get_pretty_env_info
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
import _init_paths
|
||||
from config import config
|
||||
from config import update_config
|
||||
from config import save_config
|
||||
from core.loss import build_criterion
|
||||
from core.function import train_one_epoch, test
|
||||
from dataset import build_dataloader
|
||||
from models import build_model
|
||||
from optim import build_optimizer
|
||||
from scheduler import build_lr_scheduler
|
||||
from utils.comm import comm
|
||||
from utils.utils import create_logger
|
||||
from utils.utils import init_distributed
|
||||
from utils.utils import setup_cudnn
|
||||
from utils.utils import summary_model_on_master
|
||||
from utils.utils import resume_checkpoint
|
||||
from utils.utils import save_checkpoint_on_master
|
||||
from utils.utils import save_model_on_master
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Train classification network')
|
||||
|
||||
parser.add_argument('--cfg',
|
||||
help='experiment configure file name',
|
||||
required=True,
|
||||
type=str)
|
||||
|
||||
# distributed training
|
||||
parser.add_argument("--local_rank", type=int, default=0)
|
||||
parser.add_argument("--port", type=int, default=9000)
|
||||
|
||||
parser.add_argument('opts',
|
||||
help="Modify config options using the command-line",
|
||||
default=None,
|
||||
nargs=argparse.REMAINDER)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
init_distributed(args)
|
||||
setup_cudnn(config)
|
||||
|
||||
update_config(config, args)
|
||||
final_output_dir = create_logger(config, args.cfg, 'train')
|
||||
tb_log_dir = final_output_dir
|
||||
|
||||
if comm.is_main_process():
|
||||
logging.info("=> collecting env info (might take some time)")
|
||||
logging.info("\n" + get_pretty_env_info())
|
||||
logging.info(pprint.pformat(args))
|
||||
logging.info(config)
|
||||
logging.info("=> using {} GPUs".format(args.num_gpus))
|
||||
|
||||
output_config_path = os.path.join(final_output_dir, 'config.yaml')
|
||||
logging.info("=> saving config into: {}".format(output_config_path))
|
||||
save_config(config, output_config_path)
|
||||
|
||||
model = build_model(config)
|
||||
model.to(torch.device('cuda'))
|
||||
|
||||
# copy model file
|
||||
summary_model_on_master(model, config, final_output_dir, True)
|
||||
|
||||
if config.AMP.ENABLED and config.AMP.MEMORY_FORMAT == 'nhwc':
|
||||
logging.info('=> convert memory format to nhwc')
|
||||
model.to(memory_format=torch.channels_last)
|
||||
|
||||
writer_dict = {
|
||||
'writer': SummaryWriter(logdir=tb_log_dir),
|
||||
'train_global_steps': 0,
|
||||
'valid_global_steps': 0,
|
||||
}
|
||||
|
||||
best_perf = 0.0
|
||||
best_model = True
|
||||
begin_epoch = config.TRAIN.BEGIN_EPOCH
|
||||
optimizer = build_optimizer(config, model)
|
||||
|
||||
best_perf, begin_epoch = resume_checkpoint(
|
||||
model, optimizer, config, final_output_dir, True
|
||||
)
|
||||
|
||||
train_loader = build_dataloader(config, True, args.distributed)
|
||||
valid_loader = build_dataloader(config, False, args.distributed)
|
||||
|
||||
if args.distributed:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True
|
||||
)
|
||||
|
||||
criterion = build_criterion(config)
|
||||
criterion.cuda()
|
||||
criterion_eval = build_criterion(config, train=False)
|
||||
criterion_eval.cuda()
|
||||
|
||||
lr_scheduler = build_lr_scheduler(config, optimizer, begin_epoch)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=config.AMP.ENABLED)
|
||||
|
||||
logging.info('=> start training')
|
||||
for epoch in range(begin_epoch, config.TRAIN.END_EPOCH):
|
||||
head = 'Epoch[{}]:'.format(epoch)
|
||||
logging.info('=> {} epoch start'.format(head))
|
||||
|
||||
start = time.time()
|
||||
if args.distributed:
|
||||
train_loader.sampler.set_epoch(epoch)
|
||||
|
||||
# train for one epoch
|
||||
logging.info('=> {} train start'.format(head))
|
||||
with torch.autograd.set_detect_anomaly(config.TRAIN.DETECT_ANOMALY):
|
||||
train_one_epoch(config, train_loader, model, criterion, optimizer,
|
||||
epoch, final_output_dir, tb_log_dir, writer_dict,
|
||||
scaler=scaler)
|
||||
logging.info(
|
||||
'=> {} train end, duration: {:.2f}s'
|
||||
.format(head, time.time()-start)
|
||||
)
|
||||
|
||||
# evaluate on validation set
|
||||
logging.info('=> {} validate start'.format(head))
|
||||
val_start = time.time()
|
||||
|
||||
if epoch >= config.TRAIN.EVAL_BEGIN_EPOCH:
|
||||
perf = test(
|
||||
config, valid_loader, model, criterion_eval,
|
||||
final_output_dir, tb_log_dir, writer_dict,
|
||||
args.distributed
|
||||
)
|
||||
|
||||
best_model = (perf > best_perf)
|
||||
best_perf = perf if best_model else best_perf
|
||||
|
||||
logging.info(
|
||||
'=> {} validate end, duration: {:.2f}s'
|
||||
.format(head, time.time()-val_start)
|
||||
)
|
||||
|
||||
lr_scheduler.step(epoch=epoch+1)
|
||||
if config.TRAIN.LR_SCHEDULER.METHOD == 'timm':
|
||||
lr = lr_scheduler.get_epoch_values(epoch+1)[0]
|
||||
else:
|
||||
lr = lr_scheduler.get_last_lr()[0]
|
||||
logging.info(f'=> lr: {lr}')
|
||||
|
||||
save_checkpoint_on_master(
|
||||
model=model,
|
||||
distributed=args.distributed,
|
||||
model_name=config.MODEL.NAME,
|
||||
optimizer=optimizer,
|
||||
output_dir=final_output_dir,
|
||||
in_epoch=True,
|
||||
epoch_or_step=epoch,
|
||||
best_perf=best_perf,
|
||||
)
|
||||
|
||||
if best_model and comm.is_main_process():
|
||||
save_model_on_master(
|
||||
model, args.distributed, final_output_dir, 'model_best.pth'
|
||||
)
|
||||
|
||||
if config.TRAIN.SAVE_ALL_MODELS and comm.is_main_process():
|
||||
save_model_on_master(
|
||||
model, args.distributed, final_output_dir, f'model_{epoch}.pth'
|
||||
)
|
||||
|
||||
logging.info(
|
||||
'=> {} epoch end, duration : {:.2f}s'
|
||||
.format(head, time.time()-start)
|
||||
)
|
||||
|
||||
save_model_on_master(
|
||||
model, args.distributed, final_output_dir, 'final_state.pth'
|
||||
)
|
||||
|
||||
if config.SWA.ENABLED and comm.is_main_process():
|
||||
save_model_on_master(
|
||||
args.distributed, final_output_dir, 'swa_state.pth'
|
||||
)
|
||||
|
||||
writer_dict['writer'].close()
|
||||
logging.info('=> finish training')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Загрузка…
Ссылка в новой задаче