Initial commit
This commit is contained in:
Родитель
6836c6b570
Коммит
cbc34ddd78
|
@ -0,0 +1,2 @@
|
|||
# Auto detect text files and perform LF normalization
|
||||
* text=auto
|
|
@ -0,0 +1,138 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
101
README.md
101
README.md
|
@ -1,33 +1,78 @@
|
|||
# Project
|
||||
# CoCosNet v2: Full-Resolution Correspondence Learning for Image Translation (CVPR 2021, oral presentation)<br>
|
||||
![teaser](imgs/teaser.png)
|
||||
|
||||
> This repo has been populated by an initial template to help get you started. Please
|
||||
> make sure to update the content to build a great experience for community-building.
|
||||
**CoCosNet v2: Full-Resolution Correspondence Learning for Image Translation**<br>
|
||||
**CVPR 2021, oral presentation**<br>
|
||||
[Xingran Zhou](http://xingranzh.github.io/), [Bo Zhang](https://www.microsoft.com/en-us/research/people/zhanbo/), [Ting Zhang](https://www.microsoft.com/en-us/research/people/tinzhan/), [Pan Zhang](https://panzhang0212.github.io/), [Jianmin Bao](https://jianminbao.github.io/), [Dong Chen](https://www.microsoft.com/en-us/research/people/doch/), [Zhongfei Zhang](https://www.cs.binghamton.edu/~zhongfei/), [Fang Wen](https://www.microsoft.com/en-us/research/people/fangwen/)<br>
|
||||
Paper: https://arxiv.org/pdf/2012.02047.pdf<br>
|
||||
Video: https://youtu.be/aBr1lOjm_FA<br>
|
||||
Slides: https://github.com/xingranzh/CocosNet-v2/blob/master/slides/cocosnet_v2_slides.pdf<br>
|
||||
|
||||
As the maintainer of this project, please make a few updates:
|
||||
Abstract: *We present the full-resolution correspondence learning for cross-domain images, which aids image translation. We adopt a hierarchical strategy that uses the correspondence from coarse level to guide the fine levels. At each hierarchy, the correspondence can be efficiently computed via PatchMatch that iteratively leverages the matchings from the neighborhood. Within each PatchMatch iteration, the ConvGRU module is employed to refine the current correspondence considering not only the matchings of larger context but also the historic estimates. The proposed CoCosNet v2, a GRU-assisted PatchMatch approach, is fully differentiable and highly efficient. When jointly trained with image translation, full-resolution semantic correspondence can be established in an unsupervised manner, which in turn facilitates the exemplar-based image translation. Experiments on diverse translation tasks show that CoCosNet v2 performs considerably better than state-of-the-art literature on producing high-resolution images.*<br>
|
||||
## Installation
|
||||
First please install dependencies for the experiment:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
````
|
||||
We recommend to install Pytorch version after `Pytorch 1.6.0` since we made use of [automatic mixed precision](https://pytorch.org/docs/stable/amp.html) for accelerating. (we used `Pytorch 1.7.0` in our experiments)<br>
|
||||
## Prepare the dataset
|
||||
First download the Deepfashion dataset (high resolution version) from [this link](https://drive.google.com/file/d/1bByKH1ciLXY70Bp8le_AVnjk-Hd4pe_i/view?usp=sharing). Note the file name is `img_highres.zip`. Unzip the file and rename it as `img`.<br>
|
||||
If the password is necessary, please contact [this link](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html) to access the dataset.<br>
|
||||
We use [OpenPose](https://github.com/Hzzone/pytorch-openpose) to estimate pose of DeepFashion(HD). We offer the keypoints detection results used in our experiment in [this link](https://drive.google.com/file/d/1wxrqyb67Xu_IPyZzftLgBPHDTKGQP7Pk/view?usp=sharing). Download and unzip the results file.<br>
|
||||
Since the original resolution of DeepfashionHD is 750x1101, we use a Python script to process the images to the resolution 512x512. You can find the script in `data/preprocess.py`. Note you need to download our train-val split lists `train.txt` and `val.txt` from [this link](https://drive.google.com/drive/folders/15NBujOTLnO_cRoAufWPqtOWKIinCKi0z?usp=sharing) in this step.<br>
|
||||
Finally create the root folder `DeepfashionHD`, and move the folders `img` and `pose` below it. Now the the directory structure is like:<br>
|
||||
```
|
||||
DeepfashionHD
|
||||
│
|
||||
└─── img
|
||||
│ │
|
||||
│ └─── MEN
|
||||
│ │ │ ...
|
||||
│ │
|
||||
│ └─── WOMEN
|
||||
│ │ ...
|
||||
│
|
||||
└─── pose
|
||||
│ │
|
||||
│ └─── MEN
|
||||
│ │ │ ...
|
||||
│ │
|
||||
│ └─── WOMEN
|
||||
│ │ ...
|
||||
|
||||
- Improving this README.MD file to provide a great experience
|
||||
- Updating SUPPORT.MD with content about this project's support experience
|
||||
- Understanding the security reporting process in SECURITY.MD
|
||||
- Remove this section from the README
|
||||
```
|
||||
## Inference Using Pretrained Model
|
||||
The inference results are saved in the folder `checkpoints/deepfashionHD/test`. Download the pretrained model from [this link](https://drive.google.com/file/d/1ehkrKlf5s1gfpDNXO6AC9SIZMtqs5L3N/view?usp=sharing).<br>
|
||||
Move the models below the folder `checkpoints/deepfashionHD`. Then run the following command.
|
||||
````bash
|
||||
python test.py --name deepfashionHD --dataset_mode deepfashionHD --dataroot dataset/deepfashionHD --PONO --PONO_C --no_flip --batchSize 8 --gpu_ids 0 --netCorr NoVGGHPM --nThreads 16 --nef 32 --amp --display_winsize 512 --iteration_count 5 --load_size 512 --crop_size 512
|
||||
````
|
||||
The inference results are saved in the folder `checkpoints/deepfashionHD/test`.<br>
|
||||
## Training from scratch
|
||||
Make sure you have prepared the DeepfashionHD dataset as the instruction.<br>
|
||||
Download the **pretrained VGG model** from [this link](https://drive.google.com/file/d/1D-z73DOt63BrPTgIxffN6Q4_L9qma9y8/view?usp=sharing), move it to `vgg/` folder. We use this model to calculate training loss.<br>
|
||||
Download the train-val lists from [this link](https://drive.google.com/drive/folders/15NBujOTLnO_cRoAufWPqtOWKIinCKi0z?usp=sharing), and the retrival pair lists from [this link](https://drive.google.com/drive/folders/1dJU8iq8kFiwq33nWtvj5Ql5rUh9fiXUi?usp=sharing). Note `train.txt` and `val.txt` are our train-val lists. `deepfashion_ref.txt`, `deepfashion_ref_test.txt` and `deepfashion_self_pair.txt` are the paring lists used in our experiment. Download them all and move below the folder `data/`.<br>
|
||||
Run the following command for training from scratch.
|
||||
````bash
|
||||
python train.py --name deepfashionHD --dataset_mode deepfashionHD --dataroot dataset/deepfashionHD --niter 100 --niter_decay 0 --real_reference_probability 0.0 --hard_reference_probability 0.0 --which_perceptual 4_2 --weight_perceptual 0.001 --PONO --PONO_C --vgg_normal_correct --weight_fm_ratio 1.0 --no_flip --video_like --batchSize 16 --gpu_ids 0,1,2,3,4,5,6,7 --netCorr NoVGGHPM --match_kernel 1 --featEnc_kernel 3 --display_freq 500 --print_freq 50 --save_latest_freq 2500 --save_epoch_freq 5 --nThreads 16 --weight_warp_self 500.0 --lr 0.0001 --nef 32 --amp --weight_warp_cycle 1.0 --display_winsize 512 --iteration_count 5 --temperature 0.01 --continue_train --load_size 550 --crop_size 512 --which_epoch 15
|
||||
````
|
||||
Note that `--dataroot` parameter is your DeepFashionHD dataset root, e.g. `dataset/DeepFashionHD`.<br>
|
||||
We use 8 32GB Tesla V100 GPUs to train the network. You can set `batchSize` to 16, 8 or 4 with fewer GPUs and change `gpu_ids`.
|
||||
## Citation
|
||||
If you use this code for your research, please cite our papers.
|
||||
```
|
||||
@inproceedings{zhou2021full,
|
||||
title={CoCosNet v2: Full-Resolution Correspondence Learning for Image Translation},
|
||||
author={Zhou, Xingran and Zhang, Bo and Zhang, Ting and Zhang, Pan and Bao, Jianmin and Chen, Dong and Zhang, Zhongfei and Wen, Fang},
|
||||
booktitle={CVPR},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
## Acknowledgments
|
||||
*This code borrows heavily from [CocosNet](https://github.com/microsoft/CoCosNet).
|
||||
We also thank [SPADE](https://github.com/NVlabs/SPADE) and [RAFT](https://github.com/princeton-vl/RAFT).*
|
||||
|
||||
## Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
## Trademarks
|
||||
|
||||
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
||||
trademarks or logos is subject to and must follow
|
||||
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
|
||||
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
|
||||
Any use of third-party trademarks or logos are subject to those third-party's policies.
|
||||
## License
|
||||
The codes and the pretrained model in this repository are under the MIT license as specified by the LICENSE file.<br>
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import importlib
|
||||
import torch.utils.data
|
||||
from data.base_dataset import BaseDataset
|
||||
|
||||
|
||||
def find_dataset_using_name(dataset_name):
|
||||
dataset_filename = "data." + dataset_name + "_dataset"
|
||||
datasetlib = importlib.import_module(dataset_filename)
|
||||
dataset = None
|
||||
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
|
||||
for name, cls in datasetlib.__dict__.items():
|
||||
if name.lower() == target_dataset_name.lower() \
|
||||
and issubclass(cls, BaseDataset):
|
||||
dataset = cls
|
||||
if dataset is None:
|
||||
raise ValueError("In %s.py, there should be a subclass of BaseDataset "
|
||||
"with class name that matches %s in lowercase." %
|
||||
(dataset_filename, target_dataset_name))
|
||||
return dataset
|
||||
|
||||
|
||||
def get_option_setter(dataset_name):
|
||||
dataset_class = find_dataset_using_name(dataset_name)
|
||||
return dataset_class.modify_commandline_options
|
||||
|
||||
|
||||
def create_dataloader(opt):
|
||||
dataset = find_dataset_using_name(opt.dataset_mode)
|
||||
instance = dataset()
|
||||
instance.initialize(opt)
|
||||
print("Dataset [%s] of size %d was created" % (type(instance).__name__, len(instance)))
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
instance,
|
||||
batch_size=opt.batchSize,
|
||||
shuffle=(opt.phase=='train'),
|
||||
num_workers=int(opt.nThreads),
|
||||
drop_last=(opt.phase=='train')
|
||||
)
|
||||
return dataloader
|
|
@ -0,0 +1,135 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch.utils.data as data
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
|
||||
class BaseDataset(data.Dataset):
|
||||
def __init__(self):
|
||||
super(BaseDataset, self).__init__()
|
||||
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
return parser
|
||||
|
||||
def initialize(self, opt):
|
||||
pass
|
||||
|
||||
|
||||
def get_params(opt, size):
|
||||
w, h = size
|
||||
new_h = h
|
||||
new_w = w
|
||||
if opt.preprocess_mode == 'resize_and_crop':
|
||||
new_h = new_w = opt.load_size
|
||||
elif opt.preprocess_mode == 'scale_width_and_crop':
|
||||
new_w = opt.load_size
|
||||
new_h = opt.load_size * h // w
|
||||
elif opt.preprocess_mode == 'scale_shortside_and_crop':
|
||||
ss, ls = min(w, h), max(w, h) # shortside and longside
|
||||
width_is_shorter = w == ss
|
||||
ls = int(opt.load_size * ls / ss)
|
||||
new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)
|
||||
|
||||
x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
|
||||
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
|
||||
|
||||
flip = random.random() > 0.5
|
||||
return {'crop_pos': (x, y), 'flip': flip}
|
||||
|
||||
|
||||
def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
|
||||
transform_list = []
|
||||
if opt.dataset_mode == 'flickr' and method == Image.NEAREST:
|
||||
transform_list.append(transforms.Lambda(lambda img: __add1(img)))
|
||||
if 'resize' in opt.preprocess_mode:
|
||||
osize = [opt.load_size, opt.load_size]
|
||||
transform_list.append(transforms.Resize(osize, interpolation=method))
|
||||
elif 'scale_width' in opt.preprocess_mode:
|
||||
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
|
||||
elif 'scale_shortside' in opt.preprocess_mode:
|
||||
transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method)))
|
||||
|
||||
if 'crop' in opt.preprocess_mode:
|
||||
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
|
||||
|
||||
if opt.preprocess_mode == 'none':
|
||||
base = 32
|
||||
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
|
||||
|
||||
if opt.preprocess_mode == 'fixed':
|
||||
w = opt.crop_size
|
||||
h = round(opt.crop_size / opt.aspect_ratio)
|
||||
transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method)))
|
||||
|
||||
if opt.isTrain and not opt.no_flip:
|
||||
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
|
||||
|
||||
if opt.isTrain and 'rotate' in params.keys():
|
||||
transform_list.append(transforms.Lambda(lambda img: __rotate(img, params['rotate'], method)))
|
||||
|
||||
if toTensor:
|
||||
transform_list += [transforms.ToTensor()]
|
||||
|
||||
if normalize:
|
||||
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
|
||||
(0.5, 0.5, 0.5))]
|
||||
return transforms.Compose(transform_list)
|
||||
|
||||
|
||||
def __resize(img, w, h, method=Image.BICUBIC):
|
||||
return img.resize((w, h), method)
|
||||
|
||||
|
||||
def __make_power_2(img, base, method=Image.BICUBIC):
|
||||
ow, oh = img.size
|
||||
h = int(round(oh / base) * base)
|
||||
w = int(round(ow / base) * base)
|
||||
if (h == oh) and (w == ow):
|
||||
return img
|
||||
return img.resize((w, h), method)
|
||||
|
||||
|
||||
def __scale_width(img, target_width, method=Image.BICUBIC):
|
||||
ow, oh = img.size
|
||||
if (ow == target_width):
|
||||
return img
|
||||
w = target_width
|
||||
h = int(target_width * oh / ow)
|
||||
return img.resize((w, h), method)
|
||||
|
||||
|
||||
def __scale_shortside(img, target_width, method=Image.BICUBIC):
|
||||
ow, oh = img.size
|
||||
ss, ls = min(ow, oh), max(ow, oh) # shortside and longside
|
||||
width_is_shorter = ow == ss
|
||||
if (ss == target_width):
|
||||
return img
|
||||
ls = int(target_width * ls / ss)
|
||||
nw, nh = (ss, ls) if width_is_shorter else (ls, ss)
|
||||
return img.resize((nw, nh), method)
|
||||
|
||||
|
||||
def __crop(img, pos, size):
|
||||
ow, oh = img.size
|
||||
x1, y1 = pos
|
||||
tw = th = size
|
||||
return img.crop((x1, y1, x1 + tw, y1 + th))
|
||||
|
||||
|
||||
def __flip(img, flip):
|
||||
if flip:
|
||||
return img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
return img
|
||||
|
||||
|
||||
def __rotate(img, deg, method=Image.BICUBIC):
|
||||
return img.rotate(deg, resample=method)
|
||||
|
||||
|
||||
def __add1(img):
|
||||
return Image.fromarray(np.array(img) + 1)
|
|
@ -0,0 +1,156 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
import random
|
||||
from PIL import Image
|
||||
|
||||
from data.pix2pix_dataset import Pix2pixDataset
|
||||
from data.base_dataset import get_params, get_transform
|
||||
|
||||
|
||||
class DeepFashionHDDataset(Pix2pixDataset):
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
|
||||
parser.set_defaults(preprocess_mode='resize_and_crop')
|
||||
parser.set_defaults(no_pairing_check=True)
|
||||
parser.set_defaults(load_size=550)
|
||||
parser.set_defaults(crop_size=512)
|
||||
parser.set_defaults(label_nc=20)
|
||||
parser.set_defaults(contain_dontcare_label=False)
|
||||
parser.set_defaults(cache_filelist_read=False)
|
||||
parser.set_defaults(cache_filelist_write=False)
|
||||
return parser
|
||||
|
||||
def get_paths(self, opt):
|
||||
root = opt.dataroot
|
||||
if opt.phase == 'train':
|
||||
fd = open(os.path.join('./data/train.txt'))
|
||||
lines = fd.readlines()
|
||||
fd.close()
|
||||
elif opt.phase == 'test':
|
||||
fd = open(os.path.join('./data/val.txt'))
|
||||
lines = fd.readlines()
|
||||
fd.close()
|
||||
image_paths = []
|
||||
label_paths = []
|
||||
for i in range(len(lines)):
|
||||
name = lines[i].strip()
|
||||
image_paths.append(name)
|
||||
label_path = name.replace('img', 'pose').replace('.jpg', '_{}.txt')
|
||||
label_paths.append(os.path.join(label_path))
|
||||
return label_paths, image_paths
|
||||
|
||||
def get_ref_video_like(self, opt):
|
||||
pair_path = './data/deepfashion_self_pair.txt'
|
||||
with open(pair_path) as fd:
|
||||
self_pair = fd.readlines()
|
||||
self_pair = [it.strip() for it in self_pair]
|
||||
self_pair_dict = {}
|
||||
for it in self_pair:
|
||||
items = it.split(',')
|
||||
self_pair_dict[items[0]] = items[1:]
|
||||
ref_path = './data/deepfashion_ref_test.txt' if opt.phase == 'test' else './data/deepfashion_ref.txt'
|
||||
with open(ref_path) as fd:
|
||||
ref = fd.readlines()
|
||||
ref = [it.strip() for it in ref]
|
||||
ref_dict = {}
|
||||
for i in range(len(ref)):
|
||||
items = ref[i].strip().split(',')
|
||||
if items[0] in self_pair_dict.keys():
|
||||
ref_dict[items[0]] = [it for it in self_pair_dict[items[0]]]
|
||||
else:
|
||||
ref_dict[items[0]] = [items[-1]]
|
||||
train_test_folder = ('', '')
|
||||
return ref_dict, train_test_folder
|
||||
|
||||
def get_ref_vgg(self, opt):
|
||||
extra = ''
|
||||
if opt.phase == 'test':
|
||||
extra = '_test'
|
||||
with open('./data/deepfashion_ref{}.txt'.format(extra)) as fd:
|
||||
lines = fd.readlines()
|
||||
ref_dict = {}
|
||||
for i in range(len(lines)):
|
||||
items = lines[i].strip().split(',')
|
||||
key = items[0]
|
||||
if opt.phase == 'test':
|
||||
val = [it for it in items[1:]]
|
||||
else:
|
||||
val = [items[-1]]
|
||||
ref_dict[key] = val
|
||||
train_test_folder = ('', '')
|
||||
return ref_dict, train_test_folder
|
||||
|
||||
def get_ref(self, opt):
|
||||
if opt.video_like:
|
||||
return self.get_ref_video_like(opt)
|
||||
else:
|
||||
return self.get_ref_vgg(opt)
|
||||
|
||||
def get_label_tensor(self, path):
|
||||
candidate = np.loadtxt(path.format('candidate'))
|
||||
subset = np.loadtxt(path.format('subset'))
|
||||
stickwidth = 20
|
||||
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
|
||||
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
|
||||
[1, 16], [16, 18], [3, 17], [6, 18]]
|
||||
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
|
||||
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
|
||||
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
|
||||
canvas = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
cycle_radius = 20
|
||||
for i in range(18):
|
||||
index = int(subset[i])
|
||||
if index == -1:
|
||||
continue
|
||||
x, y = candidate[index][0:2]
|
||||
cv2.circle(canvas, (int(x), int(y)), cycle_radius, colors[i], thickness=-1)
|
||||
joints = []
|
||||
for i in range(17):
|
||||
index = subset[np.array(limbSeq[i]) - 1]
|
||||
cur_canvas = canvas.copy()
|
||||
if -1 in index:
|
||||
joints.append(np.zeros_like(cur_canvas[:, :, 0]))
|
||||
continue
|
||||
Y = candidate[index.astype(int), 0]
|
||||
X = candidate[index.astype(int), 1]
|
||||
mX = np.mean(X)
|
||||
mY = np.mean(Y)
|
||||
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
|
||||
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
||||
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
||||
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
|
||||
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
|
||||
joint = np.zeros_like(cur_canvas[:, :, 0])
|
||||
cv2.fillConvexPoly(joint, polygon, 255)
|
||||
joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
|
||||
joints.append(joint)
|
||||
pose = Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)).resize((self.opt.load_size, self.opt.load_size), resample=Image.NEAREST)
|
||||
params = get_params(self.opt, pose.size)
|
||||
transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
|
||||
transform_img = get_transform(self.opt, params, method=Image.BILINEAR, normalize=False)
|
||||
tensors_dist = 0
|
||||
e = 1
|
||||
for i in range(len(joints)):
|
||||
im_dist = cv2.distanceTransform(255-joints[i], cv2.DIST_L1, 3)
|
||||
im_dist = np.clip((im_dist/3), 0, 255).astype(np.uint8)
|
||||
tensor_dist = transform_img(Image.fromarray(im_dist))
|
||||
tensors_dist = tensor_dist if e == 1 else torch.cat([tensors_dist, tensor_dist])
|
||||
e += 1
|
||||
tensor_pose = transform_label(pose)
|
||||
label_tensor = torch.cat((tensor_pose, tensors_dist), dim=0)
|
||||
return label_tensor, params
|
||||
|
||||
def imgpath_to_labelpath(self, path):
|
||||
label_path = path.replace('/img/', '/pose/').replace('.jpg', '_{}.txt')
|
||||
return label_path
|
||||
|
||||
def labelpath_to_imgpath(self, path):
|
||||
img_path = path.replace('/pose/', '/img/').replace('_{}.txt', '.jpg')
|
||||
return img_path
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import os
|
||||
import random
|
||||
from PIL import Image
|
||||
|
||||
from data.base_dataset import BaseDataset, get_params, get_transform
|
||||
|
||||
|
||||
class Pix2pixDataset(BaseDataset):
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
parser.add_argument('--no_pairing_check', action='store_true', help='If specified, skip sanity check of correct label-image file pairing')
|
||||
return parser
|
||||
|
||||
def initialize(self, opt):
|
||||
self.opt = opt
|
||||
label_paths, image_paths = self.get_paths(opt)
|
||||
label_paths = label_paths[:opt.max_dataset_size]
|
||||
image_paths = image_paths[:opt.max_dataset_size]
|
||||
if not opt.no_pairing_check:
|
||||
for path1, path2 in zip(label_paths, image_paths):
|
||||
assert self.paths_match(path1, path2), \
|
||||
"The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this." % (path1, path2)
|
||||
self.label_paths = label_paths
|
||||
self.image_paths = image_paths
|
||||
size = len(self.label_paths)
|
||||
self.dataset_size = size
|
||||
self.real_reference_probability = 1 if opt.phase == 'test' else opt.real_reference_probability
|
||||
self.hard_reference_probability = 0 if opt.phase == 'test' else opt.hard_reference_probability
|
||||
self.ref_dict, self.train_test_folder = self.get_ref(opt)
|
||||
|
||||
def get_paths(self, opt):
|
||||
label_paths = []
|
||||
image_paths = []
|
||||
assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)"
|
||||
return label_paths, image_paths
|
||||
|
||||
def paths_match(self, path1, path2):
|
||||
filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
|
||||
filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
|
||||
return filename1_without_ext == filename2_without_ext
|
||||
|
||||
def get_label_tensor(self, path):
|
||||
label = Image.open(path)
|
||||
params1 = get_params(self.opt, label.size)
|
||||
transform_label = get_transform(self.opt, params1, method=Image.NEAREST, normalize=False)
|
||||
label_tensor = transform_label(label) * 255.0
|
||||
label_tensor[label_tensor == 255] = self.opt.label_nc
|
||||
# 'unknown' is opt.label_nc
|
||||
return label_tensor, params1
|
||||
|
||||
def __getitem__(self, index):
|
||||
# label Image
|
||||
label_path = self.label_paths[index]
|
||||
label_path = os.path.join(self.opt.dataroot, label_path)
|
||||
label_tensor, params1 = self.get_label_tensor(label_path)
|
||||
# input image (real images)
|
||||
image_path = self.image_paths[index]
|
||||
image_path = os.path.join(self.opt.dataroot, image_path)
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
transform_image = get_transform(self.opt, params1)
|
||||
image_tensor = transform_image(image)
|
||||
ref_tensor = 0
|
||||
label_ref_tensor = 0
|
||||
random_p = random.random()
|
||||
if random_p < self.real_reference_probability or self.opt.phase == 'test':
|
||||
key = image_path.split('deepfashionHD/')[-1]
|
||||
val = self.ref_dict[key]
|
||||
if random_p < self.hard_reference_probability:
|
||||
#hard reference
|
||||
path_ref = val[1]
|
||||
else:
|
||||
#easy reference
|
||||
path_ref = val[0]
|
||||
if self.opt.dataset_mode == 'deepfashionHD':
|
||||
path_ref = os.path.join(self.opt.dataroot, path_ref)
|
||||
else:
|
||||
path_ref = os.path.dirname(image_path).replace(self.train_test_folder[1], self.train_test_folder[0]) + '/' + path_ref
|
||||
image_ref = Image.open(path_ref).convert('RGB')
|
||||
if self.opt.dataset_mode != 'deepfashionHD':
|
||||
path_ref_label = path_ref.replace('.jpg', '.png')
|
||||
path_ref_label = self.imgpath_to_labelpath(path_ref_label)
|
||||
else:
|
||||
path_ref_label = self.imgpath_to_labelpath(path_ref)
|
||||
label_ref_tensor, params = self.get_label_tensor(path_ref_label)
|
||||
transform_image = get_transform(self.opt, params)
|
||||
ref_tensor = transform_image(image_ref)
|
||||
self_ref_flag = 0.0
|
||||
else:
|
||||
pair = False
|
||||
if self.opt.dataset_mode == 'deepfashionHD' and self.opt.video_like:
|
||||
key = image_path.replace('\\', '/').split('deepfashionHD/')[-1]
|
||||
val = self.ref_dict[key]
|
||||
ref_name = val[0]
|
||||
key_name = key
|
||||
path_ref = os.path.join(self.opt.dataroot, ref_name)
|
||||
image_ref = Image.open(path_ref).convert('RGB')
|
||||
label_ref_path = self.imgpath_to_labelpath(path_ref)
|
||||
label_ref_tensor, params = self.get_label_tensor(label_ref_path)
|
||||
transform_image = get_transform(self.opt, params)
|
||||
ref_tensor = transform_image(image_ref)
|
||||
pair = True
|
||||
if not pair:
|
||||
label_ref_tensor, params = self.get_label_tensor(label_path)
|
||||
transform_image = get_transform(self.opt, params)
|
||||
ref_tensor = transform_image(image)
|
||||
self_ref_flag = 1.0
|
||||
input_dict = {'label': label_tensor,
|
||||
'image': image_tensor,
|
||||
'path': image_path,
|
||||
'self_ref': self_ref_flag,
|
||||
'ref': ref_tensor,
|
||||
'label_ref': label_ref_tensor
|
||||
}
|
||||
return input_dict
|
||||
|
||||
def __len__(self):
|
||||
return self.dataset_size
|
||||
|
||||
def get_ref(self, opt):
|
||||
pass
|
||||
|
||||
def imgpath_to_labelpath(self, path):
|
||||
return path
|
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 3.2 MiB |
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 344 KiB |
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 2.0 MiB |
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
import importlib
|
||||
|
||||
|
||||
def find_model_using_name(model_name):
|
||||
# Given the option --model [modelname],
|
||||
# the file "models/modelname_model.py"
|
||||
# will be imported.
|
||||
model_filename = "models." + model_name + "_model"
|
||||
modellib = importlib.import_module(model_filename)
|
||||
# In the file, the class called ModelNameModel() will
|
||||
# be instantiated. It has to be a subclass of torch.nn.Module,
|
||||
# and it is case-insensitive.
|
||||
model = None
|
||||
target_model_name = model_name.replace('_', '') + 'model'
|
||||
for name, cls in modellib.__dict__.items():
|
||||
if name.lower() == target_model_name.lower() \
|
||||
and issubclass(cls, torch.nn.Module):
|
||||
model = cls
|
||||
if model is None:
|
||||
print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." % (model_filename, target_model_name))
|
||||
exit(0)
|
||||
return model
|
||||
|
||||
|
||||
def get_option_setter(model_name):
|
||||
model_class = find_model_using_name(model_name)
|
||||
return model_class.modify_commandline_options
|
||||
|
||||
|
||||
def create_model(opt):
|
||||
model = find_model_using_name(opt.model)
|
||||
instance = model(opt)
|
||||
print("model [%s] was created" % (type(instance).__name__))
|
||||
return instance
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from util.util import feature_normalize, mse_loss
|
||||
|
||||
|
||||
|
||||
class ContextualLoss_forward(nn.Module):
|
||||
'''
|
||||
input is Al, Bl, channel = 1, range ~ [0, 255]
|
||||
'''
|
||||
|
||||
def __init__(self, opt):
|
||||
super(ContextualLoss_forward, self).__init__()
|
||||
self.opt = opt
|
||||
return None
|
||||
|
||||
def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
|
||||
'''
|
||||
X_features&Y_features are are feature vectors or feature 2d array
|
||||
h: bandwidth
|
||||
return the per-sample loss
|
||||
'''
|
||||
batch_size = X_features.shape[0]
|
||||
feature_depth = X_features.shape[1]
|
||||
feature_size = X_features.shape[2]
|
||||
|
||||
# to normalized feature vectors
|
||||
if feature_centering:
|
||||
if self.opt.PONO:
|
||||
X_features = X_features - Y_features.mean(dim=1).unsqueeze(dim=1)
|
||||
Y_features = Y_features - Y_features.mean(dim=1).unsqueeze(dim=1)
|
||||
else:
|
||||
X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(dim=-1)
|
||||
Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(dim=-1)
|
||||
|
||||
# X_features = X_features - Y_features.mean(dim=1).unsqueeze(dim=1)
|
||||
# Y_features = Y_features - Y_features.mean(dim=1).unsqueeze(dim=1)
|
||||
|
||||
X_features = feature_normalize(X_features).view(batch_size, feature_depth, -1) # batch_size * feature_depth * feature_size * feature_size
|
||||
Y_features = feature_normalize(Y_features).view(batch_size, feature_depth, -1) # batch_size * feature_depth * feature_size * feature_size
|
||||
|
||||
# X_features = F.unfold(
|
||||
# X_features, kernel_size=self.opt.match_kernel, stride=1, padding=int(self.opt.match_kernel // 2)) # batch_size * feature_depth_new * feature_size^2
|
||||
# Y_features = F.unfold(
|
||||
# Y_features, kernel_size=self.opt.match_kernel, stride=1, padding=int(self.opt.match_kernel // 2)) # batch_size * feature_depth_new * feature_size^2
|
||||
|
||||
# conine distance = 1 - similarity
|
||||
X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
|
||||
d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
|
||||
|
||||
# normalized distance: dij_bar
|
||||
# d_norm = d
|
||||
d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-3) # batch_size * feature_size^2 * feature_size^2
|
||||
|
||||
# pairwise affinity
|
||||
w = torch.exp((1 - d_norm) / h)
|
||||
A_ij = w / torch.sum(w, dim=-1, keepdim=True)
|
||||
|
||||
# contextual loss per sample
|
||||
CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1)
|
||||
loss = -torch.log(CX)
|
||||
|
||||
# contextual loss per batch
|
||||
# loss = torch.mean(loss)
|
||||
return loss
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
from models.networks.base_network import BaseNetwork
|
||||
from models.networks.loss import *
|
||||
from models.networks.discriminator import *
|
||||
from models.networks.generator import *
|
||||
from models.networks.ContextualLoss import *
|
||||
from models.networks.correspondence import *
|
||||
from models.networks.ops import *
|
||||
import util.util as util
|
||||
|
||||
|
||||
def find_network_using_name(target_network_name, filename, add=True):
|
||||
target_class_name = target_network_name + filename if add else target_network_name
|
||||
module_name = 'models.networks.' + filename
|
||||
network = util.find_class_in_module(target_class_name, module_name)
|
||||
assert issubclass(network, BaseNetwork), \
|
||||
"Class %s should be a subclass of BaseNetwork" % network
|
||||
return network
|
||||
|
||||
|
||||
def modify_commandline_options(parser, is_train):
|
||||
opt, _ = parser.parse_known_args()
|
||||
netG_cls = find_network_using_name(opt.netG, 'generator')
|
||||
parser = netG_cls.modify_commandline_options(parser, is_train)
|
||||
if is_train:
|
||||
netD_cls = find_network_using_name(opt.netD, 'discriminator')
|
||||
parser = netD_cls.modify_commandline_options(parser, is_train)
|
||||
return parser
|
||||
|
||||
|
||||
def create_network(cls, opt):
|
||||
net = cls(opt)
|
||||
net.print_network()
|
||||
if len(opt.gpu_ids) > 0:
|
||||
assert(torch.cuda.is_available())
|
||||
net.cuda()
|
||||
net.init_weights(opt.init_type, opt.init_variance)
|
||||
return net
|
||||
|
||||
|
||||
def define_G(opt):
|
||||
netG_cls = find_network_using_name(opt.netG, 'generator')
|
||||
return create_network(netG_cls, opt)
|
||||
|
||||
|
||||
def define_D(opt):
|
||||
netD_cls = find_network_using_name(opt.netD, 'discriminator')
|
||||
return create_network(netD_cls, opt)
|
||||
|
||||
def define_Corr(opt):
|
||||
netCoor_cls = find_network_using_name(opt.netCorr, 'correspondence')
|
||||
return create_network(netCoor_cls, opt)
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.utils.spectral_norm as spectral_norm
|
||||
from models.networks.normalization import SPADE
|
||||
from util.util import vgg_preprocess
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.relu = nn.PReLU()
|
||||
self.model = nn.Sequential(
|
||||
nn.ReflectionPad2d(padding),
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride),
|
||||
nn.InstanceNorm2d(out_channels),
|
||||
self.relu,
|
||||
nn.ReflectionPad2d(padding),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride),
|
||||
nn.InstanceNorm2d(out_channels),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.relu(x + self.model(x))
|
||||
return out
|
||||
|
||||
|
||||
class SPADEResnetBlock(nn.Module):
|
||||
def __init__(self, fin, fout, opt, use_se=False, dilation=1):
|
||||
super().__init__()
|
||||
# Attributes
|
||||
self.learned_shortcut = (fin != fout)
|
||||
fmiddle = min(fin, fout)
|
||||
self.opt = opt
|
||||
self.pad_type = 'nozero'
|
||||
self.use_se = use_se
|
||||
# create conv layers
|
||||
if self.pad_type != 'zero':
|
||||
self.pad = nn.ReflectionPad2d(dilation)
|
||||
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=0, dilation=dilation)
|
||||
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=0, dilation=dilation)
|
||||
else:
|
||||
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
|
||||
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
|
||||
if self.learned_shortcut:
|
||||
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
|
||||
# apply spectral norm if specified
|
||||
if 'spectral' in opt.norm_G:
|
||||
self.conv_0 = spectral_norm(self.conv_0)
|
||||
self.conv_1 = spectral_norm(self.conv_1)
|
||||
if self.learned_shortcut:
|
||||
self.conv_s = spectral_norm(self.conv_s)
|
||||
# define normalization layers
|
||||
spade_config_str = opt.norm_G.replace('spectral', '')
|
||||
if 'spade_ic' in opt:
|
||||
ic = opt.spade_ic
|
||||
else:
|
||||
ic = 4*3+opt.label_nc
|
||||
self.norm_0 = SPADE(spade_config_str, fin, ic, PONO=opt.PONO)
|
||||
self.norm_1 = SPADE(spade_config_str, fmiddle, ic, PONO=opt.PONO)
|
||||
if self.learned_shortcut:
|
||||
self.norm_s = SPADE(spade_config_str, fin, ic, PONO=opt.PONO)
|
||||
|
||||
def forward(self, x, seg1):
|
||||
x_s = self.shortcut(x, seg1)
|
||||
if self.pad_type != 'zero':
|
||||
dx = self.conv_0(self.pad(self.actvn(self.norm_0(x, seg1))))
|
||||
dx = self.conv_1(self.pad(self.actvn(self.norm_1(dx, seg1))))
|
||||
else:
|
||||
dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
|
||||
dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
|
||||
out = x_s + dx
|
||||
return out
|
||||
|
||||
def shortcut(self, x, seg1):
|
||||
if self.learned_shortcut:
|
||||
x_s = self.conv_s(self.norm_s(x, seg1))
|
||||
else:
|
||||
x_s = x
|
||||
return x_s
|
||||
|
||||
def actvn(self, x):
|
||||
return F.leaky_relu(x, 2e-1)
|
||||
|
||||
|
||||
class VGG19_feature_color_torchversion(nn.Module):
|
||||
"""
|
||||
NOTE: there is no need to pre-process the input
|
||||
input tensor should range in [0,1]
|
||||
"""
|
||||
def __init__(self, pool='max', vgg_normal_correct=False, ic=3):
|
||||
super(VGG19_feature_color_torchversion, self).__init__()
|
||||
self.vgg_normal_correct = vgg_normal_correct
|
||||
|
||||
self.conv1_1 = nn.Conv2d(ic, 64, kernel_size=3, padding=1)
|
||||
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
|
||||
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
|
||||
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
|
||||
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
|
||||
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
||||
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
||||
self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
||||
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
|
||||
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
|
||||
if pool == 'max':
|
||||
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
elif pool == 'avg':
|
||||
self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
|
||||
def forward(self, x, out_keys, preprocess=True):
|
||||
'''
|
||||
NOTE: input tensor should range in [0,1]
|
||||
'''
|
||||
out = {}
|
||||
if preprocess:
|
||||
x = vgg_preprocess(x, vgg_normal_correct=self.vgg_normal_correct)
|
||||
out['r11'] = F.relu(self.conv1_1(x))
|
||||
out['r12'] = F.relu(self.conv1_2(out['r11']))
|
||||
out['p1'] = self.pool1(out['r12'])
|
||||
out['r21'] = F.relu(self.conv2_1(out['p1']))
|
||||
out['r22'] = F.relu(self.conv2_2(out['r21']))
|
||||
out['p2'] = self.pool2(out['r22'])
|
||||
out['r31'] = F.relu(self.conv3_1(out['p2']))
|
||||
out['r32'] = F.relu(self.conv3_2(out['r31']))
|
||||
out['r33'] = F.relu(self.conv3_3(out['r32']))
|
||||
out['r34'] = F.relu(self.conv3_4(out['r33']))
|
||||
out['p3'] = self.pool3(out['r34'])
|
||||
out['r41'] = F.relu(self.conv4_1(out['p3']))
|
||||
out['r42'] = F.relu(self.conv4_2(out['r41']))
|
||||
out['r43'] = F.relu(self.conv4_3(out['r42']))
|
||||
out['r44'] = F.relu(self.conv4_4(out['r43']))
|
||||
out['p4'] = self.pool4(out['r44'])
|
||||
out['r51'] = F.relu(self.conv5_1(out['p4']))
|
||||
out['r52'] = F.relu(self.conv5_2(out['r51']))
|
||||
out['r53'] = F.relu(self.conv5_3(out['r52']))
|
||||
out['r54'] = F.relu(self.conv5_4(out['r53']))
|
||||
out['p5'] = self.pool5(out['r54'])
|
||||
return [out[key] for key in out_keys]
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn import init
|
||||
|
||||
|
||||
class BaseNetwork(nn.Module):
|
||||
def __init__(self):
|
||||
super(BaseNetwork, self).__init__()
|
||||
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
return parser
|
||||
|
||||
def print_network(self):
|
||||
if isinstance(self, list):
|
||||
self = self[0]
|
||||
num_params = 0
|
||||
for param in self.parameters():
|
||||
num_params += param.numel()
|
||||
print('Network [%s] was created. Total number of parameters: %.1f million. '
|
||||
'To see the architecture, do print(network).'
|
||||
% (type(self).__name__, num_params / 1000000))
|
||||
|
||||
def init_weights(self, init_type='normal', gain=0.02):
|
||||
def init_func(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('BatchNorm2d') != -1:
|
||||
if hasattr(m, 'weight') and m.weight is not None:
|
||||
init.normal_(m.weight.data, 1.0, gain)
|
||||
if hasattr(m, 'bias') and m.bias is not None:
|
||||
init.constant_(m.bias.data, 0.0)
|
||||
elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
||||
if init_type == 'normal':
|
||||
init.normal_(m.weight.data, 0.0, gain)
|
||||
elif init_type == 'xavier':
|
||||
init.xavier_normal_(m.weight.data, gain=gain)
|
||||
elif init_type == 'xavier_uniform':
|
||||
init.xavier_uniform_(m.weight.data, gain=1.0)
|
||||
elif init_type == 'kaiming':
|
||||
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
||||
elif init_type == 'orthogonal':
|
||||
init.orthogonal_(m.weight.data, gain=gain)
|
||||
elif init_type == 'none': # uses pytorch's default init method
|
||||
m.reset_parameters()
|
||||
else:
|
||||
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
||||
if hasattr(m, 'bias') and m.bias is not None:
|
||||
init.constant_(m.bias.data, 0.0)
|
||||
self.apply(init_func)
|
||||
# propagate to children
|
||||
for m in self.children():
|
||||
if hasattr(m, 'init_weights'):
|
||||
m.init_weights(init_type, gain)
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FlowHead(nn.Module):
|
||||
def __init__(self, input_dim=32, hidden_dim=64):
|
||||
super(FlowHead, self).__init__()
|
||||
candidate_num = 16
|
||||
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
||||
self.conv2 = nn.Conv2d(hidden_dim, 2*candidate_num, 3, padding=1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv2(x)
|
||||
num = x.size()[1]
|
||||
delta_offset_x, delta_offset_y = torch.split(x, [num//2, num//2], dim=1)
|
||||
return delta_offset_x, delta_offset_y
|
||||
|
||||
|
||||
class SepConvGRU(nn.Module):
|
||||
def __init__(self, hidden_dim=32, input_dim=64):
|
||||
super(SepConvGRU, self).__init__()
|
||||
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
||||
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
||||
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
||||
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
||||
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
||||
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
||||
|
||||
def forward(self, h, x):
|
||||
# horizontal
|
||||
hx = torch.cat([h, x], dim=1)
|
||||
z = torch.sigmoid(self.convz1(hx))
|
||||
r = torch.sigmoid(self.convr1(hx))
|
||||
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
|
||||
h = (1-z) * h + z * q
|
||||
# vertical
|
||||
hx = torch.cat([h, x], dim=1)
|
||||
z = torch.sigmoid(self.convz2(hx))
|
||||
r = torch.sigmoid(self.convr2(hx))
|
||||
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
|
||||
h = (1-z) * h + z * q
|
||||
return h
|
||||
|
||||
|
||||
class BasicMotionEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super(BasicMotionEncoder, self).__init__()
|
||||
candidate_num = 16
|
||||
self.convc1 = nn.Conv2d(candidate_num, 64, 1, padding=0)
|
||||
self.convc2 = nn.Conv2d(64, 64, 3, padding=1)
|
||||
self.convf1 = nn.Conv2d(2*candidate_num, 64, 7, padding=3)
|
||||
self.convf2 = nn.Conv2d(64, 64, 3, padding=1)
|
||||
self.conv = nn.Conv2d(64+64, 64-2*candidate_num, 3, padding=1)
|
||||
|
||||
def forward(self, flow, corr):
|
||||
cor = F.relu(self.convc1(corr))
|
||||
cor = F.relu(self.convc2(cor))
|
||||
flo = F.relu(self.convf1(flow))
|
||||
flo = F.relu(self.convf2(flo))
|
||||
cor_flo = torch.cat([cor, flo], dim=1)
|
||||
out = F.relu(self.conv(cor_flo))
|
||||
return torch.cat([out, flow], dim=1)
|
||||
|
||||
|
||||
class BasicUpdateBlock(nn.Module):
|
||||
def __init__(self):
|
||||
super(BasicUpdateBlock, self).__init__()
|
||||
self.encoder = BasicMotionEncoder()
|
||||
self.gru = SepConvGRU()
|
||||
self.flow_head = FlowHead()
|
||||
|
||||
def forward(self, net, corr, flow):
|
||||
motion_features = self.encoder(flow, corr)
|
||||
inp = motion_features
|
||||
net = self.gru(net, inp)
|
||||
delta_offset_x, delta_offset_y = self.flow_head(net)
|
||||
return net, delta_offset_x, delta_offset_y
|
|
@ -0,0 +1,245 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import util.util as util
|
||||
from models.networks.base_network import BaseNetwork
|
||||
from models.networks.architecture import ResidualBlock
|
||||
from models.networks.normalization import get_nonspade_norm_layer
|
||||
from models.networks.architecture import SPADEResnetBlock
|
||||
"""patch match"""
|
||||
from models.networks.patch_match import PatchMatchGRU
|
||||
from models.networks.ops import *
|
||||
|
||||
|
||||
def match_kernel_and_pono_c(feature, match_kernel, PONO_C, eps=1e-10):
|
||||
b, c, h, w = feature.size()
|
||||
if match_kernel == 1:
|
||||
feature = feature.view(b, c, -1)
|
||||
else:
|
||||
feature = F.unfold(feature, kernel_size=match_kernel, padding=int(match_kernel//2))
|
||||
dim_mean = 1 if PONO_C else -1
|
||||
feature = feature - feature.mean(dim=dim_mean, keepdim=True)
|
||||
feature_norm = torch.norm(feature, 2, 1, keepdim=True) + eps
|
||||
feature = torch.div(feature, feature_norm)
|
||||
return feature.view(b, -1, h, w)
|
||||
|
||||
|
||||
"""512x512"""
|
||||
class AdaptiveFeatureGenerator(BaseNetwork):
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
kw = opt.featEnc_kernel
|
||||
pw = int((kw-1)//2)
|
||||
nf = opt.nef
|
||||
norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
|
||||
self.layer1 = norm_layer(nn.Conv2d(opt.spade_ic, nf, 3, stride=1, padding=pw))
|
||||
self.layer2 = nn.Sequential(
|
||||
norm_layer(nn.Conv2d(nf * 1, nf * 2, 3, 1, 1)),
|
||||
ResidualBlock(nf * 2, nf * 2),
|
||||
)
|
||||
self.layer3 = nn.Sequential(
|
||||
norm_layer(nn.Conv2d(nf * 2, nf * 4, kw, stride=2, padding=pw)),
|
||||
ResidualBlock(nf * 4, nf * 4),
|
||||
)
|
||||
self.layer4 = nn.Sequential(
|
||||
norm_layer(nn.Conv2d(nf * 4, nf * 4, kw, stride=2, padding=pw)),
|
||||
ResidualBlock(nf * 4, nf * 4),
|
||||
)
|
||||
self.layer5 = nn.Sequential(
|
||||
norm_layer(nn.Conv2d(nf * 4, nf * 4, kw, stride=2, padding=pw)),
|
||||
ResidualBlock(nf * 4, nf * 4),
|
||||
)
|
||||
self.head_0 = SPADEResnetBlock(nf * 4, nf * 4, opt)
|
||||
self.G_middle_0 = SPADEResnetBlock(nf * 4, nf * 4, opt)
|
||||
self.G_middle_1 = SPADEResnetBlock(nf * 4, nf * 2, opt)
|
||||
self.G_middle_2 = SPADEResnetBlock(nf * 2, nf * 1, opt)
|
||||
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
||||
|
||||
def forward(self, input, seg):
|
||||
# 512
|
||||
x1 = self.layer1(input)
|
||||
# 512
|
||||
x2 = self.layer2(self.actvn(x1))
|
||||
# 256
|
||||
x3 = self.layer3(self.actvn(x2))
|
||||
# 128
|
||||
x4 = self.layer4(self.actvn(x3))
|
||||
# 64
|
||||
x5 = self.layer5(self.actvn(x4))
|
||||
# bottleneck
|
||||
x6 = self.head_0(x5, seg)
|
||||
# 128
|
||||
x7 = self.G_middle_0(self.up(x6) + x4, seg)
|
||||
# 256
|
||||
x8 = self.G_middle_1(self.up(x7) + x3, seg)
|
||||
# 512
|
||||
x9 = self.G_middle_2(self.up(x8) + x2, seg)
|
||||
return [x6, x7, x8, x9]
|
||||
|
||||
def actvn(self, x):
|
||||
return F.leaky_relu(x, 2e-1)
|
||||
|
||||
|
||||
class NoVGGHPMCorrespondence(BaseNetwork):
|
||||
def __init__(self, opt):
|
||||
self.opt = opt
|
||||
super().__init__()
|
||||
opt.spade_ic = opt.semantic_nc
|
||||
self.adaptive_model_seg = AdaptiveFeatureGenerator(opt)
|
||||
opt.spade_ic = 3 + opt.semantic_nc
|
||||
self.adaptive_model_img = AdaptiveFeatureGenerator(opt)
|
||||
del opt.spade_ic
|
||||
self.batch_size = opt.batchSize
|
||||
"""512x512"""
|
||||
feature_channel = opt.nef
|
||||
self.phi_0 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
|
||||
self.phi_1 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
|
||||
self.phi_2 = nn.Conv2d(in_channels=feature_channel*2, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
|
||||
self.phi_3 = nn.Conv2d(in_channels=feature_channel, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
|
||||
self.theta_0 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
|
||||
self.theta_1 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
|
||||
self.theta_2 = nn.Conv2d(in_channels=feature_channel*2, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
|
||||
self.theta_3 = nn.Conv2d(in_channels=feature_channel, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
|
||||
self.patch_match = PatchMatchGRU(opt)
|
||||
|
||||
"""512x512"""
|
||||
def multi_scale_patch_match(self, f1, f2, ref, hierarchical_scale, pre=None, real_img=None):
|
||||
if hierarchical_scale == 0:
|
||||
y_cycle = None
|
||||
scale = 64
|
||||
batch_size, channel, feature_height, feature_width = f1.size()
|
||||
ref = F.avg_pool2d(ref, 8, stride=8)
|
||||
ref = ref.view(batch_size, 3, scale * scale)
|
||||
f1 = f1.view(batch_size, channel, scale * scale)
|
||||
f2 = f2.view(batch_size, channel, scale * scale)
|
||||
matmul_result = torch.matmul(f1.permute(0, 2, 1), f2)/self.opt.temperature
|
||||
mat = F.softmax(matmul_result, dim=-1)
|
||||
y = torch.matmul(mat, ref.permute(0, 2, 1))
|
||||
if self.opt.phase is 'train' and self.opt.weight_warp_cycle > 0:
|
||||
mat_cycle = F.softmax(matmul_result.transpose(1, 2), dim=-1)
|
||||
y_cycle = torch.matmul(mat_cycle, y)
|
||||
y_cycle = y_cycle.permute(0, 2, 1).view(batch_size, 3, scale, scale)
|
||||
y = y.permute(0, 2, 1).view(batch_size, 3, scale, scale)
|
||||
return mat, y, y_cycle
|
||||
if hierarchical_scale == 1:
|
||||
scale = 128
|
||||
with torch.no_grad():
|
||||
batch_size, channel, feature_height, feature_width = f1.size()
|
||||
topk_num = 1
|
||||
search_window = 4
|
||||
centering = 1
|
||||
dilation = 2
|
||||
total_candidate_num = topk_num * (search_window ** 2)
|
||||
topk_inds = torch.topk(pre, topk_num, dim=-1)[-1]
|
||||
inds = topk_inds.permute(0, 2, 1).view(batch_size, topk_num, (scale//2), (scale//2)).float()
|
||||
offset_x, offset_y = inds_to_offset(inds)
|
||||
dx = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=1).expand(-1, search_window).contiguous().view(-1) - centering
|
||||
dy = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=0).expand(search_window, -1).contiguous().view(-1) - centering
|
||||
dx = dx.view(1, search_window ** 2, 1, 1) * dilation
|
||||
dy = dy.view(1, search_window ** 2, 1, 1) * dilation
|
||||
offset_x_up = F.interpolate((2 * offset_x + dx), scale_factor=2)
|
||||
offset_y_up = F.interpolate((2 * offset_y + dy), scale_factor=2)
|
||||
ref = F.avg_pool2d(ref, 4, stride=4)
|
||||
ref = ref.view(batch_size, 3, scale * scale)
|
||||
mat, y = self.patch_match(f1, f2, ref, offset_x_up, offset_y_up)
|
||||
y = y.view(batch_size, 3, scale, scale)
|
||||
return mat, y
|
||||
if hierarchical_scale == 2:
|
||||
scale = 256
|
||||
with torch.no_grad():
|
||||
batch_size, channel, feature_height, feature_width = f1.size()
|
||||
topk_num = 1
|
||||
search_window = 4
|
||||
centering = 1
|
||||
dilation = 2
|
||||
total_candidate_num = topk_num * (search_window ** 2)
|
||||
topk_inds = pre[:, :, :topk_num]
|
||||
inds = topk_inds.permute(0, 2, 1).view(batch_size, topk_num, (scale//2), (scale//2)).float()
|
||||
offset_x, offset_y = inds_to_offset(inds)
|
||||
dx = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=1).expand(-1, search_window).contiguous().view(-1) - centering
|
||||
dy = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=0).expand(search_window, -1).contiguous().view(-1) - centering
|
||||
dx = dx.view(1, search_window ** 2, 1, 1) * dilation
|
||||
dy = dy.view(1, search_window ** 2, 1, 1) * dilation
|
||||
offset_x_up = F.interpolate((2 * offset_x + dx), scale_factor=2)
|
||||
offset_y_up = F.interpolate((2 * offset_y + dy), scale_factor=2)
|
||||
ref = F.avg_pool2d(ref, 2, stride=2)
|
||||
ref = ref.view(batch_size, 3, scale * scale)
|
||||
mat, y = self.patch_match(f1, f2, ref, offset_x_up, offset_y_up)
|
||||
y = y.view(batch_size, 3, scale, scale)
|
||||
return mat, y
|
||||
if hierarchical_scale == 3:
|
||||
scale = 512
|
||||
with torch.no_grad():
|
||||
batch_size, channel, feature_height, feature_width = f1.size()
|
||||
topk_num = 1
|
||||
search_window = 4
|
||||
centering = 1
|
||||
dilation = 2
|
||||
total_candidate_num = topk_num * (search_window ** 2)
|
||||
topk_inds = pre[:, :, :topk_num]
|
||||
inds = topk_inds.permute(0, 2, 1).view(batch_size, topk_num, (scale//2), (scale//2)).float()
|
||||
offset_x, offset_y = inds_to_offset(inds)
|
||||
dx = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=1).expand(-1, search_window).contiguous().view(-1) - centering
|
||||
dy = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=0).expand(search_window, -1).contiguous().view(-1) - centering
|
||||
dx = dx.view(1, search_window ** 2, 1, 1) * dilation
|
||||
dy = dy.view(1, search_window ** 2, 1, 1) * dilation
|
||||
offset_x_up = F.interpolate((2 * offset_x + dx), scale_factor=2)
|
||||
offset_y_up = F.interpolate((2 * offset_y + dx), scale_factor=2)
|
||||
ref = ref.view(batch_size, 3, scale * scale)
|
||||
mat, y = self.patch_match(f1, f2, ref, offset_x_up, offset_y_up)
|
||||
y = y.view(batch_size, 3, scale, scale)
|
||||
return mat, y
|
||||
|
||||
def forward(self, ref_img, real_img, seg_map, ref_seg_map):
|
||||
corr_out = {}
|
||||
seg_input = seg_map
|
||||
adaptive_feature_seg = self.adaptive_model_seg(seg_input, seg_input)
|
||||
ref_input = torch.cat((ref_img, ref_seg_map), dim=1)
|
||||
adaptive_feature_img = self.adaptive_model_img(ref_input, ref_input)
|
||||
for i in range(len(adaptive_feature_seg)):
|
||||
adaptive_feature_seg[i] = util.feature_normalize(adaptive_feature_seg[i])
|
||||
adaptive_feature_img[i] = util.feature_normalize(adaptive_feature_img[i])
|
||||
if self.opt.isTrain and self.opt.weight_novgg_featpair > 0:
|
||||
real_input = torch.cat((real_img, seg_map), dim=1)
|
||||
adaptive_feature_img_pair = self.adaptive_model_img(real_input, real_input)
|
||||
loss_novgg_featpair = 0
|
||||
weights = [1.0, 1.0, 1.0, 1.0]
|
||||
for i in range(len(adaptive_feature_img_pair)):
|
||||
adaptive_feature_img_pair[i] = util.feature_normalize(adaptive_feature_img_pair[i])
|
||||
loss_novgg_featpair += F.l1_loss(adaptive_feature_seg[i], adaptive_feature_img_pair[i]) * weights[i]
|
||||
corr_out['loss_novgg_featpair'] = loss_novgg_featpair * self.opt.weight_novgg_featpair
|
||||
cont_features = adaptive_feature_seg
|
||||
ref_features = adaptive_feature_img
|
||||
theta = []
|
||||
phi = []
|
||||
"""512x512"""
|
||||
theta.append(match_kernel_and_pono_c(self.theta_0(cont_features[0]), self.opt.match_kernel, self.opt.PONO_C))
|
||||
theta.append(match_kernel_and_pono_c(self.theta_1(cont_features[1]), self.opt.match_kernel, self.opt.PONO_C))
|
||||
theta.append(match_kernel_and_pono_c(self.theta_2(cont_features[2]), self.opt.match_kernel, self.opt.PONO_C))
|
||||
theta.append(match_kernel_and_pono_c(self.theta_3(cont_features[3]), self.opt.match_kernel, self.opt.PONO_C))
|
||||
phi.append(match_kernel_and_pono_c(self.phi_0(ref_features[0]), self.opt.match_kernel, self.opt.PONO_C))
|
||||
phi.append(match_kernel_and_pono_c(self.phi_1(ref_features[1]), self.opt.match_kernel, self.opt.PONO_C))
|
||||
phi.append(match_kernel_and_pono_c(self.phi_2(ref_features[2]), self.opt.match_kernel, self.opt.PONO_C))
|
||||
phi.append(match_kernel_and_pono_c(self.phi_3(ref_features[3]), self.opt.match_kernel, self.opt.PONO_C))
|
||||
ref = ref_img
|
||||
ys = []
|
||||
m = None
|
||||
for i in range(len(theta)):
|
||||
if i == 0:
|
||||
m, y, y_cycle = self.multi_scale_patch_match(theta[i], phi[i], ref, i, pre=m)
|
||||
if y_cycle is not None:
|
||||
corr_out['warp_cycle'] = y_cycle
|
||||
else:
|
||||
m, y = self.multi_scale_patch_match(theta[i], phi[i], ref, i, pre=m)
|
||||
ys.append(y)
|
||||
corr_out['warp_out'] = ys
|
||||
return corr_out
|
|
@ -0,0 +1,114 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.networks.base_network import BaseNetwork
|
||||
from models.networks.normalization import get_nonspade_norm_layer
|
||||
import util.util as util
|
||||
|
||||
|
||||
class MultiscaleDiscriminator(BaseNetwork):
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
parser.add_argument('--netD_subarch', type=str, default='n_layer',
|
||||
help='architecture of each discriminator')
|
||||
parser.add_argument('--num_D', type=int, default=2,
|
||||
help='number of discriminators to be used in multiscale')
|
||||
opt, _ = parser.parse_known_args()
|
||||
# define properties of each discriminator of the multiscale discriminator
|
||||
subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator', \
|
||||
'models.networks.discriminator')
|
||||
subnetD.modify_commandline_options(parser, is_train)
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
for i in range(opt.num_D):
|
||||
subnetD = self.create_single_discriminator(opt)
|
||||
self.add_module('discriminator_%d' % i, subnetD)
|
||||
|
||||
def create_single_discriminator(self, opt):
|
||||
subarch = opt.netD_subarch
|
||||
if subarch == 'n_layer':
|
||||
netD = NLayerDiscriminator(opt)
|
||||
else:
|
||||
raise ValueError('unrecognized discriminator subarchitecture %s' % subarch)
|
||||
return netD
|
||||
|
||||
def downsample(self, input):
|
||||
return F.avg_pool2d(input, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
|
||||
|
||||
def forward(self, input):
|
||||
result = []
|
||||
get_intermediate_features = not self.opt.no_ganFeat_loss
|
||||
for name, D in self.named_children():
|
||||
out = D(input)
|
||||
if not get_intermediate_features:
|
||||
out = [out]
|
||||
result.append(out)
|
||||
input = self.downsample(input)
|
||||
return result
|
||||
|
||||
|
||||
class NLayerDiscriminator(BaseNetwork):
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
parser.add_argument('--n_layers_D', type=int, default=4, help='# layers in each discriminator')
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
kw = 4
|
||||
padw = int((kw - 1.0) / 2)
|
||||
nf = opt.ndf
|
||||
input_nc = self.compute_D_input_nc(opt)
|
||||
norm_layer = get_nonspade_norm_layer(opt, opt.norm_D)
|
||||
sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
|
||||
nn.LeakyReLU(0.2, False)]]
|
||||
for n in range(1, opt.n_layers_D):
|
||||
nf_prev = nf
|
||||
nf = min(nf * 2, 512)
|
||||
stride = 1 if n == opt.n_layers_D - 1 else 2
|
||||
if n == opt.n_layers_D - 1:
|
||||
dec = []
|
||||
nc_dec = nf_prev
|
||||
for _ in range(opt.n_layers_D - 1):
|
||||
dec += [nn.Upsample(scale_factor=2),
|
||||
norm_layer(nn.Conv2d(nc_dec, int(nc_dec//2), kernel_size=3, stride=1, padding=1)),
|
||||
nn.LeakyReLU(0.2, False)]
|
||||
nc_dec = int(nc_dec // 2)
|
||||
dec += [nn.Conv2d(nc_dec, opt.semantic_nc, kernel_size=3, stride=1, padding=1)]
|
||||
self.dec = nn.Sequential(*dec)
|
||||
sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
|
||||
nn.LeakyReLU(0.2, False)]]
|
||||
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
|
||||
for n in range(len(sequence)):
|
||||
self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
|
||||
|
||||
def compute_D_input_nc(self, opt):
|
||||
input_nc = opt.label_nc + opt.output_nc
|
||||
if opt.contain_dontcare_label:
|
||||
input_nc += 1
|
||||
return input_nc
|
||||
|
||||
def forward(self, input):
|
||||
results = [input]
|
||||
seg = None
|
||||
cam_logit = None
|
||||
for name, submodel in self.named_children():
|
||||
if 'model' not in name:
|
||||
continue
|
||||
x = results[-1]
|
||||
intermediate_output = submodel(x)
|
||||
results.append(intermediate_output)
|
||||
get_intermediate_features = not self.opt.no_ganFeat_loss
|
||||
if get_intermediate_features:
|
||||
retu = results[1:]
|
||||
else:
|
||||
retu = results[-1]
|
||||
return retu
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
from models.networks.base_network import BaseNetwork
|
||||
from models.networks.architecture import SPADEResnetBlock
|
||||
|
||||
|
||||
class SPADEGenerator(BaseNetwork):
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
nf = opt.ngf
|
||||
self.sw, self.sh = self.compute_latent_vector_size(opt)
|
||||
ic = 4*3+opt.label_nc
|
||||
self.fc = nn.Conv2d(ic, 8 * nf, 3, padding=1)
|
||||
self.head_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
|
||||
self.G_middle_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
|
||||
self.G_middle_1 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
|
||||
self.up_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
|
||||
self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
|
||||
self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
|
||||
self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)
|
||||
final_nc = nf
|
||||
self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
|
||||
self.up = nn.Upsample(scale_factor=2)
|
||||
|
||||
def compute_latent_vector_size(self, opt):
|
||||
num_up_layers = 5
|
||||
sw = opt.crop_size // (2**num_up_layers)
|
||||
sh = round(sw / opt.aspect_ratio)
|
||||
return sw, sh
|
||||
|
||||
def forward(self, input, warp_out=None):
|
||||
seg = torch.cat((F.interpolate(warp_out[0], size=(512, 512)), F.interpolate(warp_out[1], size=(512, 512)), F.interpolate(warp_out[2], size=(512, 512)), warp_out[3], input), dim=1)
|
||||
x = F.interpolate(seg, size=(self.sh, self.sw))
|
||||
x = self.fc(x)
|
||||
x = self.head_0(x, seg)
|
||||
x = self.up(x)
|
||||
x = self.G_middle_0(x, seg)
|
||||
x = self.G_middle_1(x, seg)
|
||||
x = self.up(x)
|
||||
x = self.up_0(x, seg)
|
||||
x = self.up(x)
|
||||
x = self.up_1(x, seg)
|
||||
x = self.up(x)
|
||||
x = self.up_2(x, seg)
|
||||
x = self.up(x)
|
||||
x = self.up_3(x, seg)
|
||||
x = self.conv_img(F.leaky_relu(x, 2e-1))
|
||||
x = torch.tanh(x)
|
||||
return x
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class GANLoss(nn.Module):
|
||||
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
|
||||
tensor=torch.FloatTensor, opt=None):
|
||||
super(GANLoss, self).__init__()
|
||||
self.real_label = target_real_label
|
||||
self.fake_label = target_fake_label
|
||||
self.real_label_tensor = None
|
||||
self.fake_label_tensor = None
|
||||
self.zero_tensor = None
|
||||
self.Tensor = tensor
|
||||
self.gan_mode = gan_mode
|
||||
self.opt = opt
|
||||
if gan_mode == 'ls':
|
||||
pass
|
||||
elif gan_mode == 'original':
|
||||
pass
|
||||
elif gan_mode == 'w':
|
||||
pass
|
||||
elif gan_mode == 'hinge':
|
||||
pass
|
||||
else:
|
||||
raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
|
||||
|
||||
def get_target_tensor(self, input, target_is_real):
|
||||
if target_is_real:
|
||||
if self.real_label_tensor is None:
|
||||
self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
|
||||
self.real_label_tensor.requires_grad_(False)
|
||||
return self.real_label_tensor.expand_as(input)
|
||||
else:
|
||||
if self.fake_label_tensor is None:
|
||||
self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
|
||||
self.fake_label_tensor.requires_grad_(False)
|
||||
return self.fake_label_tensor.expand_as(input)
|
||||
|
||||
def get_zero_tensor(self, input):
|
||||
if self.zero_tensor is None:
|
||||
self.zero_tensor = self.Tensor(1).fill_(0)
|
||||
self.zero_tensor.requires_grad_(False)
|
||||
return self.zero_tensor.expand_as(input).type_as(input)
|
||||
|
||||
def loss(self, input, target_is_real, for_discriminator=True):
|
||||
if self.gan_mode == 'original': # cross entropy loss
|
||||
target_tensor = self.get_target_tensor(input, target_is_real)
|
||||
loss = F.binary_cross_entropy_with_logits(input, target_tensor)
|
||||
return loss
|
||||
elif self.gan_mode == 'ls':
|
||||
target_tensor = self.get_target_tensor(input, target_is_real)
|
||||
return F.mse_loss(input, target_tensor)
|
||||
elif self.gan_mode == 'hinge':
|
||||
if for_discriminator:
|
||||
if target_is_real:
|
||||
minval = torch.min(input - 1, self.get_zero_tensor(input))
|
||||
loss = -torch.mean(minval)
|
||||
else:
|
||||
minval = torch.min(-input - 1, self.get_zero_tensor(input))
|
||||
loss = -torch.mean(minval)
|
||||
else:
|
||||
assert target_is_real, "The generator's hinge loss must be aiming for real"
|
||||
loss = -torch.mean(input)
|
||||
return loss
|
||||
else:
|
||||
# wgan
|
||||
if target_is_real:
|
||||
return -input.mean()
|
||||
else:
|
||||
return input.mean()
|
||||
|
||||
def __call__(self, input, target_is_real, for_discriminator=True):
|
||||
if isinstance(input, list):
|
||||
loss = 0
|
||||
for pred_i in input:
|
||||
if isinstance(pred_i, list):
|
||||
pred_i = pred_i[-1]
|
||||
loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
|
||||
bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
|
||||
new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
|
||||
loss += new_loss
|
||||
return loss / len(input)
|
||||
else:
|
||||
return self.loss(input, target_is_real, for_discriminator)
|
|
@ -0,0 +1,97 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.utils.spectral_norm as spectral_norm
|
||||
|
||||
|
||||
def get_nonspade_norm_layer(opt, norm_type='instance'):
|
||||
def get_out_channel(layer):
|
||||
if hasattr(layer, 'out_channels'):
|
||||
return getattr(layer, 'out_channels')
|
||||
return layer.weight.size(0)
|
||||
def add_norm_layer(layer):
|
||||
nonlocal norm_type
|
||||
if norm_type.startswith('spectral'):
|
||||
layer = spectral_norm(layer)
|
||||
subnorm_type = norm_type[len('spectral'):]
|
||||
else:
|
||||
subnorm_type =norm_type
|
||||
if subnorm_type == 'none' or len(subnorm_type) == 0:
|
||||
return layer
|
||||
if getattr(layer, 'bias', None) is not None:
|
||||
delattr(layer, 'bias')
|
||||
layer.register_parameter('bias', None)
|
||||
if subnorm_type == 'batch':
|
||||
norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
|
||||
elif subnorm_type == 'sync_batch':
|
||||
norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
|
||||
elif subnorm_type == 'instance':
|
||||
norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
|
||||
else:
|
||||
raise ValueError('normalization layer %s is not recognized' % subnorm_type)
|
||||
return nn.Sequential(layer, norm_layer)
|
||||
return add_norm_layer
|
||||
|
||||
|
||||
def PositionalNorm2d(x, epsilon=1e-8):
|
||||
# x: B*C*W*H normalize in C dim
|
||||
mean = x.mean(dim=1, keepdim=True)
|
||||
std = x.var(dim=1, keepdim=True).add(epsilon).sqrt()
|
||||
output = (x - mean) / std
|
||||
return output
|
||||
|
||||
|
||||
class SPADE(nn.Module):
|
||||
def __init__(self, config_text, norm_nc, label_nc, PONO=False):
|
||||
super().__init__()
|
||||
assert config_text.startswith('spade')
|
||||
parsed = re.search('spade(\D+)(\d)x\d', config_text)
|
||||
param_free_norm_type = str(parsed.group(1))
|
||||
ks = int(parsed.group(2))
|
||||
self.pad_type = 'nozero'
|
||||
if PONO:
|
||||
self.param_free_norm = PositionalNorm2d
|
||||
elif param_free_norm_type == 'instance':
|
||||
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
|
||||
elif param_free_norm_type == 'syncbatch':
|
||||
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=True)
|
||||
elif param_free_norm_type == 'batch':
|
||||
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=True)
|
||||
else:
|
||||
raise ValueError('%s is not a recognized param-free norm type in SPADE' % param_free_norm_type)
|
||||
nhidden = 128
|
||||
pw = ks // 2
|
||||
if self.pad_type != 'zero':
|
||||
self.mlp_shared = nn.Sequential(
|
||||
nn.ReflectionPad2d(pw),
|
||||
nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=0),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.pad = nn.ReflectionPad2d(pw)
|
||||
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=0)
|
||||
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=0)
|
||||
|
||||
else:
|
||||
self.mlp_shared = nn.Sequential(
|
||||
nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
|
||||
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
|
||||
|
||||
def forward(self, x, segmap):
|
||||
normalized = self.param_free_norm(x)
|
||||
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
|
||||
actv = self.mlp_shared(segmap)
|
||||
if self.pad_type != 'zero':
|
||||
gamma = self.mlp_gamma(self.pad(actv))
|
||||
beta = self.mlp_beta(self.pad(actv))
|
||||
else:
|
||||
gamma = self.mlp_gamma(actv)
|
||||
beta = self.mlp_beta(actv)
|
||||
out = normalized * (1 + gamma) + beta
|
||||
return out
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def convert_1d_to_2d(index, base=64):
|
||||
x = index // base
|
||||
y = index % base
|
||||
return x,y
|
||||
|
||||
|
||||
def convert_2d_to_1d(x, y, base=64):
|
||||
return x*base+y
|
||||
|
||||
|
||||
def batch_meshgrid(shape, device):
|
||||
batch_size, _, height, width = shape
|
||||
x_range = torch.arange(0.0, width, device=device)
|
||||
y_range = torch.arange(0.0, height, device=device)
|
||||
x_coordinate, y_coordinate = torch.meshgrid(x_range, y_range)
|
||||
x_coordinate = x_coordinate.expand(batch_size, -1, -1).unsqueeze(1)
|
||||
y_coordinate = y_coordinate.expand(batch_size, -1, -1).unsqueeze(1)
|
||||
return x_coordinate, y_coordinate
|
||||
|
||||
|
||||
def inds_to_offset(inds):
|
||||
"""
|
||||
inds: b x number x h x w
|
||||
"""
|
||||
shape = inds.size()
|
||||
device = inds.device
|
||||
x_coordinate, y_coordinate = batch_meshgrid(shape, device)
|
||||
batch_size, _, height, width = shape
|
||||
x = inds // width
|
||||
y = inds % width
|
||||
return x - x_coordinate, y - y_coordinate
|
||||
|
||||
|
||||
def offset_to_inds(offset_x, offset_y):
|
||||
shape = offset_x.size()
|
||||
device = offset_x.device
|
||||
x_coordinate, y_coordinate = batch_meshgrid(shape, device)
|
||||
h, w = offset_x.size()[2:]
|
||||
x = torch.clamp(x_coordinate + offset_x, 0, h-1)
|
||||
y = torch.clamp(y_coordinate + offset_y, 0, w-1)
|
||||
return x * offset_x.size()[3] + y
|
|
@ -0,0 +1,169 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import random
|
||||
|
||||
from models.networks.convgru import BasicUpdateBlock
|
||||
from models.networks.ops import *
|
||||
|
||||
|
||||
"""patch match"""
|
||||
class Evaluate(nn.Module):
|
||||
def __init__(self, temperature):
|
||||
super().__init__()
|
||||
self.filter_size = 3
|
||||
self.temperature = temperature
|
||||
|
||||
def forward(self, left_features, right_features, offset_x, offset_y):
|
||||
device = left_features.get_device()
|
||||
batch_size, num, height, width = offset_x.size()
|
||||
channel = left_features.size()[1]
|
||||
matching_inds = offset_to_inds(offset_x, offset_y)
|
||||
matching_inds = matching_inds.view(batch_size, num, height * width).permute(0, 2, 1).long()
|
||||
base_batch = torch.arange(batch_size).to(device).long() * (height * width)
|
||||
base_batch = base_batch.view(-1, 1, 1)
|
||||
matching_inds_add_base = matching_inds + base_batch
|
||||
right_features_view = right_features
|
||||
match_cost = []
|
||||
# using A[:, idx]
|
||||
for i in range(matching_inds_add_base.size()[-1]):
|
||||
idx = matching_inds_add_base[:, :, i]
|
||||
idx = idx.contiguous().view(-1)
|
||||
right_features_select = right_features_view[:, idx]
|
||||
right_features_select = right_features_select.view(channel, batch_size, -1).transpose(0, 1)
|
||||
match_cost_i = torch.sum(left_features * right_features_select, dim=1, keepdim=True) / self.temperature
|
||||
match_cost.append(match_cost_i)
|
||||
match_cost = torch.cat(match_cost, dim=1).transpose(1, 2)
|
||||
match_cost = F.softmax(match_cost, dim=-1)
|
||||
match_cost_topk, match_cost_topk_indices = torch.topk(match_cost, num//self.filter_size, dim=-1)
|
||||
matching_inds = torch.gather(matching_inds, -1, match_cost_topk_indices)
|
||||
matching_inds = matching_inds.permute(0, 2, 1).view(batch_size, -1, height, width).float()
|
||||
offset_x, offset_y = inds_to_offset(matching_inds)
|
||||
corr = match_cost_topk.permute(0, 2, 1)
|
||||
return offset_x, offset_y, corr
|
||||
|
||||
|
||||
class PropagationFaster(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, offset_x, offset_y, propagation_type="horizontal"):
|
||||
device = offset_x.get_device()
|
||||
self.horizontal_zeros = torch.zeros((offset_x.size()[0], offset_x.size()[1], offset_x.size()[2], 1)).to(device)
|
||||
self.vertical_zeros = torch.zeros((offset_x.size()[0], offset_x.size()[1], 1, offset_x.size()[3])).to(device)
|
||||
if propagation_type is "horizontal":
|
||||
offset_x = torch.cat((torch.cat((self.horizontal_zeros, offset_x[:, :, :, :-1]), dim=3),
|
||||
offset_x,
|
||||
torch.cat((offset_x[:, :, :, 1:], self.horizontal_zeros), dim=3)), dim=1)
|
||||
|
||||
offset_y = torch.cat((torch.cat((self.horizontal_zeros, offset_y[:, :, :, :-1]), dim=3),
|
||||
offset_y,
|
||||
torch.cat((offset_y[:, :, :, 1:], self.horizontal_zeros), dim=3)), dim=1)
|
||||
|
||||
else:
|
||||
offset_x = torch.cat((torch.cat((self.vertical_zeros, offset_x[:, :, :-1, :]), dim=2),
|
||||
offset_x,
|
||||
torch.cat((offset_x[:, :, 1:, :], self.vertical_zeros), dim=2)), dim=1)
|
||||
|
||||
offset_y = torch.cat((torch.cat((self.vertical_zeros, offset_y[:, :, :-1, :]), dim=2),
|
||||
offset_y,
|
||||
torch.cat((offset_y[:, :, 1:, :], self.vertical_zeros), dim=2)), dim=1)
|
||||
return offset_x, offset_y
|
||||
|
||||
|
||||
class PatchMatchOnce(nn.Module):
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.propagation = PropagationFaster()
|
||||
self.evaluate = Evaluate(opt.temperature)
|
||||
|
||||
def forward(self, left_features, right_features, offset_x, offset_y):
|
||||
prob = random.random()
|
||||
if prob < 0.5:
|
||||
offset_x, offset_y = self.propagation(offset_x, offset_y, "horizontal")
|
||||
offset_x, offset_y, _ = self.evaluate(left_features, right_features, offset_x, offset_y)
|
||||
offset_x, offset_y = self.propagation(offset_x, offset_y, "vertical")
|
||||
offset_x, offset_y, corr = self.evaluate(left_features, right_features, offset_x, offset_y)
|
||||
else:
|
||||
offset_x, offset_y = self.propagation(offset_x, offset_y, "vertical")
|
||||
offset_x, offset_y, _ = self.evaluate(left_features, right_features, offset_x, offset_y)
|
||||
offset_x, offset_y = self.propagation(offset_x, offset_y, "horizontal")
|
||||
offset_x, offset_y, corr = self.evaluate(left_features, right_features, offset_x, offset_y)
|
||||
return offset_x, offset_y, corr
|
||||
|
||||
|
||||
class PatchMatchGRU(nn.Module):
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.patch_match_one_step = PatchMatchOnce(opt)
|
||||
self.temperature = opt.temperature
|
||||
self.iters = opt.iteration_count
|
||||
input_dim = opt.nef
|
||||
hidden_dim = 32
|
||||
norm = nn.InstanceNorm2d(hidden_dim, affine=False)
|
||||
relu = nn.ReLU(inplace=True)
|
||||
"""
|
||||
concat left and right features
|
||||
"""
|
||||
self.initial_layer = nn.Sequential(
|
||||
nn.Conv2d(input_dim*2, hidden_dim, kernel_size=3, padding=1, stride=1),
|
||||
norm,
|
||||
relu,
|
||||
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
|
||||
norm,
|
||||
relu,
|
||||
)
|
||||
self.refine_net = BasicUpdateBlock()
|
||||
|
||||
def forward(self, left_features, right_features, right_input, initial_offset_x, initial_offset_y):
|
||||
device = left_features.get_device()
|
||||
batch_size, channel, height, width = left_features.size()
|
||||
num = initial_offset_x.size()[1]
|
||||
initial_input = torch.cat((left_features, right_features), dim=1)
|
||||
hidden = self.initial_layer(initial_input)
|
||||
left_features = left_features.view(batch_size, -1, height * width)
|
||||
right_features = right_features.view(batch_size, -1, height * width)
|
||||
right_features_view = right_features.transpose(0, 1).contiguous().view(channel, -1)
|
||||
with torch.no_grad():
|
||||
offset_x, offset_y = initial_offset_x, initial_offset_y
|
||||
for it in range(self.iters):
|
||||
with torch.no_grad():
|
||||
offset_x, offset_y, corr = self.patch_match_one_step(left_features, right_features_view, offset_x, offset_y)
|
||||
"""GRU refinement"""
|
||||
flow = torch.cat((offset_x, offset_y), dim=1)
|
||||
corr = corr.view(batch_size, -1, height, width)
|
||||
hidden, delta_offset_x, delta_offset_y = self.refine_net(hidden, corr, flow)
|
||||
offset_x = offset_x + delta_offset_x
|
||||
offset_y = offset_y + delta_offset_y
|
||||
with torch.no_grad():
|
||||
matching_inds = offset_to_inds(offset_x, offset_y)
|
||||
matching_inds = matching_inds.view(batch_size, num, height * width).permute(0, 2, 1).long()
|
||||
base_batch = torch.arange(batch_size).to(device).long() * (height * width)
|
||||
base_batch = base_batch.view(-1, 1, 1)
|
||||
matching_inds_plus_base = matching_inds + base_batch
|
||||
match_cost = []
|
||||
# using A[:, idx]
|
||||
for i in range(matching_inds_plus_base.size()[-1]):
|
||||
idx = matching_inds_plus_base[:, :, i]
|
||||
idx = idx.contiguous().view(-1)
|
||||
right_features_select = right_features_view[:, idx]
|
||||
right_features_select = right_features_select.view(channel, batch_size, -1).transpose(0, 1)
|
||||
match_cost_i = torch.sum(left_features * right_features_select, dim=1, keepdim=True) / self.temperature
|
||||
match_cost.append(match_cost_i)
|
||||
match_cost = torch.cat(match_cost, dim=1).transpose(1, 2)
|
||||
match_cost = F.softmax(match_cost, dim=-1)
|
||||
right_input_view = right_input.transpose(0, 1).contiguous().view(right_input.size()[1], -1)
|
||||
warp = torch.zeros_like(right_input)
|
||||
# using A[:, idx]
|
||||
for i in range(match_cost.size()[-1]):
|
||||
idx = matching_inds_plus_base[:, :, i]
|
||||
idx = idx.contiguous().view(-1)
|
||||
right_input_select = right_input_view[:, idx]
|
||||
right_input_select = right_input_select.view(right_input.size()[1], batch_size, -1).transpose(0, 1)
|
||||
warp = warp + right_input_select * match_cost[:, :, i].unsqueeze(dim=1)
|
||||
return matching_inds, warp
|
|
@ -0,0 +1,255 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import models.networks as networks
|
||||
import util.util as util
|
||||
import itertools
|
||||
try:
|
||||
from torch.cuda.amp import autocast
|
||||
except:
|
||||
# dummy autocast for PyTorch < 1.6
|
||||
class autocast:
|
||||
def __init__(self, enabled):
|
||||
pass
|
||||
def __enter__(self):
|
||||
pass
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
class Pix2PixModel(torch.nn.Module):
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
networks.modify_commandline_options(parser, is_train)
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \
|
||||
else torch.FloatTensor
|
||||
self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \
|
||||
else torch.ByteTensor
|
||||
self.net = torch.nn.ModuleDict(self.initialize_networks(opt))
|
||||
# set loss functions
|
||||
if opt.isTrain:
|
||||
# vgg network
|
||||
self.vggnet_fix = networks.architecture.VGG19_feature_color_torchversion(vgg_normal_correct=opt.vgg_normal_correct)
|
||||
self.vggnet_fix.load_state_dict(torch.load('vgg/vgg19_conv.pth'))
|
||||
self.vggnet_fix.eval()
|
||||
for param in self.vggnet_fix.parameters():
|
||||
param.requires_grad = False
|
||||
self.vggnet_fix.to(self.opt.gpu_ids[0])
|
||||
# contextual loss
|
||||
self.contextual_forward_loss = networks.ContextualLoss_forward(opt)
|
||||
# GAN loss
|
||||
self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
|
||||
# L1 loss
|
||||
self.criterionFeat = torch.nn.L1Loss()
|
||||
# L2 loss
|
||||
self.MSE_loss = torch.nn.MSELoss()
|
||||
# setting which layer is used in the perceptual loss
|
||||
if opt.which_perceptual == '5_2':
|
||||
self.perceptual_layer = -1
|
||||
elif opt.which_perceptual == '4_2':
|
||||
self.perceptual_layer = -2
|
||||
|
||||
def forward(self, data, mode, GforD=None):
|
||||
input_label, input_semantics, real_image, self_ref, ref_image, ref_label, ref_semantics = self.preprocess_input(data, )
|
||||
generated_out = {}
|
||||
|
||||
if mode == 'generator':
|
||||
g_loss, generated_out = self.compute_generator_loss(input_label, \
|
||||
input_semantics, real_image, ref_label, \
|
||||
ref_semantics, ref_image, self_ref)
|
||||
out = {}
|
||||
out['fake_image'] = generated_out['fake_image']
|
||||
out['input_semantics'] = input_semantics
|
||||
out['ref_semantics'] = ref_semantics
|
||||
out['warp_out'] = None if 'warp_out' not in generated_out else generated_out['warp_out']
|
||||
out['adaptive_feature_seg'] = None if 'adaptive_feature_seg' not in generated_out else generated_out['adaptive_feature_seg']
|
||||
out['adaptive_feature_img'] = None if 'adaptive_feature_img' not in generated_out else generated_out['adaptive_feature_img']
|
||||
out['warp_cycle'] = None if 'warp_cycle' not in generated_out else generated_out['warp_cycle']
|
||||
return g_loss, out
|
||||
|
||||
elif mode == 'discriminator':
|
||||
d_loss = self.compute_discriminator_loss(input_semantics, \
|
||||
real_image, GforD, label=input_label)
|
||||
return d_loss
|
||||
|
||||
elif mode == 'inference':
|
||||
out = {}
|
||||
with torch.no_grad():
|
||||
out = self.inference(input_semantics, ref_semantics=ref_semantics, \
|
||||
ref_image=ref_image, self_ref=self_ref, \
|
||||
real_image=real_image)
|
||||
out['input_semantics'] = input_semantics
|
||||
out['ref_semantics'] = ref_semantics
|
||||
return out
|
||||
|
||||
else:
|
||||
raise ValueError("|mode| is invalid")
|
||||
|
||||
def create_optimizers(self, opt):
|
||||
if opt.no_TTUR:
|
||||
beta1, beta2 = opt.beta1, opt.beta2
|
||||
G_lr, D_lr = opt.lr, opt.lr
|
||||
else:
|
||||
beta1, beta2 = 0, 0.9
|
||||
G_lr, D_lr = opt.lr / 2, opt.lr * 2
|
||||
optimizer_G = torch.optim.Adam(itertools.chain(self.net['netG'].parameters(), \
|
||||
self.net['netCorr'].parameters()), lr=G_lr, betas=(beta1, beta2), eps=1e-3)
|
||||
optimizer_D = torch.optim.Adam(itertools.chain(self.net['netD'].parameters()), \
|
||||
lr=D_lr, betas=(beta1, beta2))
|
||||
return optimizer_G, optimizer_D
|
||||
|
||||
def save(self, epoch):
|
||||
util.save_network(self.net['netG'], 'G', epoch, self.opt)
|
||||
util.save_network(self.net['netD'], 'D', epoch, self.opt)
|
||||
util.save_network(self.net['netCorr'], 'Corr', epoch, self.opt)
|
||||
|
||||
def initialize_networks(self, opt):
|
||||
net = {}
|
||||
net['netG'] = networks.define_G(opt)
|
||||
net['netD'] = networks.define_D(opt) if opt.isTrain else None
|
||||
net['netCorr'] = networks.define_Corr(opt)
|
||||
if not opt.isTrain or opt.continue_train:
|
||||
net['netCorr'] = util.load_network(net['netCorr'], 'Corr', opt.which_epoch, opt)
|
||||
net['netG'] = util.load_network(net['netG'], 'G', opt.which_epoch, opt)
|
||||
if opt.isTrain:
|
||||
net['netD'] = util.load_network(net['netD'], 'D', opt.which_epoch, opt)
|
||||
return net
|
||||
|
||||
def preprocess_input(self, data):
|
||||
if self.use_gpu():
|
||||
for k in data.keys():
|
||||
try:
|
||||
data[k] = data[k].cuda()
|
||||
except:
|
||||
continue
|
||||
label = data['label'][:,:3,:,:].float()
|
||||
label_ref = data['label_ref'][:,:3,:,:].float()
|
||||
input_semantics = data['label'].float()
|
||||
ref_semantics = data['label_ref'].float()
|
||||
image = data['image']
|
||||
ref = data['ref']
|
||||
self_ref = data['self_ref']
|
||||
return label, input_semantics, image, self_ref, ref, label_ref, ref_semantics
|
||||
|
||||
def get_ctx_loss(self, source, target):
|
||||
contextual_style5_1 = torch.mean(self.contextual_forward_loss(source[-1], target[-1].detach())) * 8
|
||||
contextual_style4_1 = torch.mean(self.contextual_forward_loss(source[-2], target[-2].detach())) * 4
|
||||
contextual_style3_1 = torch.mean(self.contextual_forward_loss(F.avg_pool2d(source[-3], 2), F.avg_pool2d(target[-3].detach(), 2))) * 2
|
||||
return contextual_style5_1 + contextual_style4_1 + contextual_style3_1
|
||||
|
||||
def compute_generator_loss(self, input_label, input_semantics, real_image, ref_label=None, ref_semantics=None, ref_image=None, self_ref=None):
|
||||
G_losses = {}
|
||||
generate_out = self.generate_fake(input_semantics, real_image, ref_semantics=ref_semantics, ref_image=ref_image, self_ref=self_ref)
|
||||
generate_out['fake_image'] = generate_out['fake_image'].float()
|
||||
weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
|
||||
sample_weights = self_ref/(sum(self_ref)+1e-5)
|
||||
sample_weights = sample_weights.view(-1, 1, 1, 1)
|
||||
"""domain align"""
|
||||
if 'loss_novgg_featpair' in generate_out and generate_out['loss_novgg_featpair'] is not None:
|
||||
G_losses['no_vgg_feat'] = generate_out['loss_novgg_featpair']
|
||||
"""warping cycle"""
|
||||
if self.opt.weight_warp_cycle > 0:
|
||||
warp_cycle = generate_out['warp_cycle']
|
||||
scale_factor = ref_image.size()[-1] // warp_cycle.size()[-1]
|
||||
ref = F.avg_pool2d(ref_image, scale_factor, stride=scale_factor)
|
||||
G_losses['G_warp_cycle'] = F.l1_loss(warp_cycle, ref) * self.opt.weight_warp_cycle
|
||||
"""warping loss"""
|
||||
if self.opt.weight_warp_self > 0:
|
||||
"""512x512"""
|
||||
warp1, warp2, warp3, warp4 = generate_out['warp_out']
|
||||
G_losses['G_warp_self'] = \
|
||||
torch.mean(F.l1_loss(warp4, real_image, reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0 + \
|
||||
torch.mean(F.l1_loss(warp3, F.avg_pool2d(real_image, 2, stride=2), reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0 + \
|
||||
torch.mean(F.l1_loss(warp2, F.avg_pool2d(real_image, 4, stride=4), reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0 + \
|
||||
torch.mean(F.l1_loss(warp1, F.avg_pool2d(real_image, 8, stride=8), reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0
|
||||
"""gan loss"""
|
||||
pred_fake, pred_real = self.discriminate(input_semantics, generate_out['fake_image'], real_image)
|
||||
G_losses['GAN'] = self.criterionGAN(pred_fake, True, for_discriminator=False) * self.opt.weight_gan
|
||||
if not self.opt.no_ganFeat_loss:
|
||||
num_D = len(pred_fake)
|
||||
GAN_Feat_loss = 0.0
|
||||
for i in range(num_D):
|
||||
# for each discriminator
|
||||
# last output is the final prediction, so we exclude it
|
||||
num_intermediate_outputs = len(pred_fake[i]) - 1
|
||||
for j in range(num_intermediate_outputs):
|
||||
# for each layer output
|
||||
unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
|
||||
GAN_Feat_loss += unweighted_loss * self.opt.weight_ganFeat / num_D
|
||||
G_losses['GAN_Feat'] = GAN_Feat_loss
|
||||
"""feature matching loss"""
|
||||
fake_features = self.vggnet_fix(generate_out['fake_image'], ['r12', 'r22', 'r32', 'r42', 'r52'], preprocess=True)
|
||||
loss = 0
|
||||
for i in range(len(generate_out['real_features'])):
|
||||
loss += weights[i] * util.weighted_l1_loss(fake_features[i], generate_out['real_features'][i].detach(), sample_weights)
|
||||
G_losses['fm'] = loss * self.opt.weight_vgg * self.opt.weight_fm_ratio
|
||||
"""perceptual loss"""
|
||||
feat_loss = util.mse_loss(fake_features[self.perceptual_layer], generate_out['real_features'][self.perceptual_layer].detach())
|
||||
G_losses['perc'] = feat_loss * self.opt.weight_perceptual
|
||||
"""contextual loss"""
|
||||
G_losses['contextual'] = self.get_ctx_loss(fake_features, generate_out['ref_features']) * self.opt.weight_vgg * self.opt.weight_contextual
|
||||
return G_losses, generate_out
|
||||
|
||||
def compute_discriminator_loss(self, input_semantics, real_image, GforD, label=None):
|
||||
D_losses = {}
|
||||
with torch.no_grad():
|
||||
fake_image = GforD['fake_image'].detach()
|
||||
fake_image.requires_grad_()
|
||||
pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)
|
||||
D_losses['D_Fake'] = self.criterionGAN(pred_fake, False, for_discriminator=True) * self.opt.weight_gan
|
||||
D_losses['D_real'] = self.criterionGAN(pred_real, True, for_discriminator=True) * self.opt.weight_gan
|
||||
return D_losses
|
||||
|
||||
def encode_z(self, real_image):
|
||||
mu, logvar = self.net['netE'](real_image)
|
||||
z = self.reparameterize(mu, logvar)
|
||||
return z, mu, logvar
|
||||
|
||||
def generate_fake(self, input_semantics, real_image, ref_semantics=None, ref_image=None, self_ref=None):
|
||||
generate_out = {}
|
||||
generate_out['ref_features'] = self.vggnet_fix(ref_image, ['r12', 'r22', 'r32', 'r42', 'r52'], preprocess=True)
|
||||
generate_out['real_features'] = self.vggnet_fix(real_image, ['r12', 'r22', 'r32', 'r42', 'r52'], preprocess=True)
|
||||
with autocast(enabled=self.opt.amp):
|
||||
corr_out = self.net['netCorr'](ref_image, real_image, input_semantics, ref_semantics)
|
||||
generate_out['fake_image'] = self.net['netG'](input_semantics, warp_out=corr_out['warp_out'])
|
||||
generate_out = {**generate_out, **corr_out}
|
||||
return generate_out
|
||||
|
||||
def inference(self, input_semantics, ref_semantics=None, ref_image=None, self_ref=None, real_image=None):
|
||||
generate_out = {}
|
||||
with autocast(enabled=self.opt.amp):
|
||||
corr_out = self.net['netCorr'](ref_image, real_image, input_semantics, ref_semantics)
|
||||
generate_out['fake_image'] = self.net['netG'](input_semantics, warp_out=corr_out['warp_out'])
|
||||
generate_out = {**generate_out, **corr_out}
|
||||
return generate_out
|
||||
|
||||
def discriminate(self, input_semantics, fake_image, real_image):
|
||||
fake_concat = torch.cat([input_semantics, fake_image], dim=1)
|
||||
real_concat = torch.cat([input_semantics, real_image], dim=1)
|
||||
fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
|
||||
with autocast(enabled=self.opt.amp):
|
||||
discriminator_out = self.net['netD'](fake_and_real)
|
||||
pred_fake, pred_real = self.divide_pred(discriminator_out)
|
||||
return pred_fake, pred_real
|
||||
|
||||
def divide_pred(self, pred):
|
||||
if type(pred) == list:
|
||||
fake = []
|
||||
real = []
|
||||
for p in pred:
|
||||
fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
|
||||
real.append([tensor[tensor.size(0) // 2:] for tensor in p])
|
||||
else:
|
||||
fake = pred[:pred.size(0) // 2]
|
||||
real = pred[pred.size(0) // 2:]
|
||||
return fake, real
|
||||
|
||||
def use_gpu(self):
|
||||
return len(self.opt.gpu_ids) > 0
|
|
@ -0,0 +1,2 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
|
@ -0,0 +1,166 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
import argparse
|
||||
import pickle
|
||||
import numpy as np
|
||||
import torch
|
||||
import models
|
||||
import data
|
||||
from util import util
|
||||
|
||||
|
||||
class BaseOptions():
|
||||
def __init__(self):
|
||||
self.initialized = False
|
||||
|
||||
def initialize(self, parser):
|
||||
# experiment specifics
|
||||
parser.add_argument('--name', type=str, default='deepfashionHD', help='name of the experiment. It decides where to store samples and models')
|
||||
parser.add_argument('--gpu_ids', type=str, default='0,1,2,3', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
|
||||
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
|
||||
parser.add_argument('--model', type=str, default='pix2pix', help='which model to use')
|
||||
parser.add_argument('--norm_G', type=str, default='spectralinstance', help='instance normalization or batch normalization')
|
||||
parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization')
|
||||
parser.add_argument('--norm_E', type=str, default='spectralinstance', help='instance normalization or batch normalization')
|
||||
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
|
||||
# input/output sizes
|
||||
parser.add_argument('--batchSize', type=int, default=4, help='input batch size')
|
||||
parser.add_argument('--preprocess_mode', type=str, default='scale_width_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none"))
|
||||
parser.add_argument('--load_size', type=int, default=256, help='Scale images to this size. The final image will be cropped to --crop_size.')
|
||||
parser.add_argument('--crop_size', type=int, default=256, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
|
||||
parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. The final height of the load image will be crop_size/aspect_ratio')
|
||||
parser.add_argument('--label_nc', type=int, default=182, help='# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dopntcare_label.')
|
||||
parser.add_argument('--contain_dontcare_label', action='store_true', help='if the label map contains dontcare label (dontcare=255)')
|
||||
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
|
||||
# for setting inputs
|
||||
parser.add_argument('--dataroot', type=str, default='dataset/deepfashionHD')
|
||||
parser.add_argument('--dataset_mode', type=str, default='deepfashionHD')
|
||||
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
|
||||
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation')
|
||||
parser.add_argument('--nThreads', default=16, type=int, help='# threads for loading data')
|
||||
parser.add_argument('--max_dataset_size', type=int, default=sys.maxsize, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
|
||||
parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default')
|
||||
parser.add_argument('--cache_filelist_write', action='store_true', help='saves the current filelist into a text file, so that it loads faster')
|
||||
parser.add_argument('--cache_filelist_read', action='store_true', help='reads from the file list cache')
|
||||
# for displays
|
||||
parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
|
||||
# for generator
|
||||
parser.add_argument('--netG', type=str, default='spade', help='selects model to use for netG (pix2pixhd | spade)')
|
||||
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
|
||||
parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
|
||||
parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
|
||||
# for feature encoder
|
||||
parser.add_argument('--netCorr', type=str, default='NoVGGHPM')
|
||||
parser.add_argument('--nef', type=int, default=32, help='# of gen filters in first conv layer')
|
||||
# for instance-wise features
|
||||
parser.add_argument('--CBN_intype', type=str, default='warp_mask', help='type of CBN input for framework, warp/mask/warp_mask')
|
||||
parser.add_argument('--match_kernel', type=int, default=1, help='correspondence matrix match kernel size')
|
||||
parser.add_argument('--featEnc_kernel', type=int, default=3, help='kernel size in domain adaptor')
|
||||
parser.add_argument('--PONO', action='store_true', help='use positional normalization ')
|
||||
parser.add_argument('--PONO_C', action='store_true', help='use C normalization in corr module')
|
||||
parser.add_argument('--vgg_normal_correct', action='store_true', help='if true, correct vgg normalization and replace vgg FM model with ctx model')
|
||||
parser.add_argument('--use_coordconv', action='store_true', help='if true, use coordconv in CorrNet')
|
||||
parser.add_argument('--video_like', action='store_true', help='useful in deepfashion')
|
||||
parser.add_argument('--amp', action='store_true', help='use torch.cuda.amp')
|
||||
parser.add_argument('--temperature', type=float, default=0.01)
|
||||
parser.add_argument('--iteration_count', type=int, default=5)
|
||||
self.initialized = True
|
||||
return parser
|
||||
|
||||
def gather_options(self):
|
||||
# initialize parser with basic options
|
||||
if not self.initialized:
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser = self.initialize(parser)
|
||||
# get the basic options
|
||||
opt, unknown = parser.parse_known_args()
|
||||
# modify model-related parser options
|
||||
model_name = opt.model
|
||||
model_option_setter = models.get_option_setter(model_name)
|
||||
parser = model_option_setter(parser, self.isTrain)
|
||||
# modify dataset-related parser options
|
||||
dataset_mode = opt.dataset_mode
|
||||
dataset_option_setter = data.get_option_setter(dataset_mode)
|
||||
parser = dataset_option_setter(parser, self.isTrain)
|
||||
opt, unknown = parser.parse_known_args()
|
||||
# if there is opt_file, load it.
|
||||
# The previous default options will be overwritten
|
||||
if opt.load_from_opt_file:
|
||||
parser = self.update_options_from_file(parser, opt)
|
||||
opt = parser.parse_args()
|
||||
self.parser = parser
|
||||
return opt
|
||||
|
||||
def print_options(self, opt):
|
||||
message = ''
|
||||
message += '----------------- Options ---------------\n'
|
||||
for k, v in sorted(vars(opt).items()):
|
||||
comment = ''
|
||||
default = self.parser.get_default(k)
|
||||
if v != default:
|
||||
comment = '\t[default: %s]' % str(default)
|
||||
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
||||
message += '----------------- End -------------------'
|
||||
print(message)
|
||||
|
||||
def option_file_path(self, opt, makedir=False):
|
||||
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
||||
if makedir:
|
||||
util.mkdirs(expr_dir)
|
||||
file_name = os.path.join(expr_dir, 'opt')
|
||||
return file_name
|
||||
|
||||
def save_options(self, opt):
|
||||
file_name = self.option_file_path(opt, makedir=True)
|
||||
with open(file_name + '.txt', 'wt') as opt_file:
|
||||
for k, v in sorted(vars(opt).items()):
|
||||
comment = ''
|
||||
default = self.parser.get_default(k)
|
||||
if v != default:
|
||||
comment = '\t[default: %s]' % str(default)
|
||||
opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
|
||||
with open(file_name + '.pkl', 'wb') as opt_file:
|
||||
pickle.dump(opt, opt_file)
|
||||
|
||||
def update_options_from_file(self, parser, opt):
|
||||
new_opt = self.load_options(opt)
|
||||
for k, v in sorted(vars(opt).items()):
|
||||
if hasattr(new_opt, k) and v != getattr(new_opt, k):
|
||||
new_val = getattr(new_opt, k)
|
||||
parser.set_defaults(**{k: new_val})
|
||||
return parser
|
||||
|
||||
def load_options(self, opt):
|
||||
file_name = self.option_file_path(opt, makedir=False)
|
||||
new_opt = pickle.load(open(file_name + '.pkl', 'rb'))
|
||||
return new_opt
|
||||
|
||||
def parse(self, save=False):
|
||||
# gather options from base, train, dataset, model
|
||||
opt = self.gather_options()
|
||||
# train or test
|
||||
opt.isTrain = self.isTrain
|
||||
self.print_options(opt)
|
||||
if opt.isTrain:
|
||||
self.save_options(opt)
|
||||
# Set semantic_nc based on the option.
|
||||
# This will be convenient in many places
|
||||
opt.semantic_nc = opt.label_nc + (1 if opt.contain_dontcare_label else 0)
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
|
||||
str_ids = opt.gpu_ids.split(',')
|
||||
opt.gpu_ids = list(range(len(str_ids)))
|
||||
seed = 1234
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
if len(opt.gpu_ids) > 0:
|
||||
torch.cuda.set_device(opt.gpu_ids[0])
|
||||
self.opt = opt
|
||||
return self.opt
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .base_options import BaseOptions
|
||||
|
||||
|
||||
class TestOptions(BaseOptions):
|
||||
def initialize(self, parser):
|
||||
BaseOptions.initialize(self, parser)
|
||||
parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
|
||||
parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
|
||||
parser.add_argument('--how_many', type=int, default=float("inf"), help='how many test images to run')
|
||||
parser.add_argument('--save_per_img', action='store_true', help='if specified, save per image')
|
||||
parser.add_argument('--show_corr', action='store_true', help='if specified, save bilinear upsample correspondence')
|
||||
parser.set_defaults(preprocess_mode='scale_width_and_crop', crop_size=256, load_size=256, display_winsize=256)
|
||||
parser.set_defaults(serial_batches=True)
|
||||
parser.set_defaults(no_flip=True)
|
||||
parser.set_defaults(phase='test')
|
||||
self.isTrain = False
|
||||
return parser
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from .base_options import BaseOptions
|
||||
|
||||
|
||||
class TrainOptions(BaseOptions):
|
||||
def initialize(self, parser):
|
||||
BaseOptions.initialize(self, parser)
|
||||
# for displays
|
||||
parser.add_argument('--display_freq', type=int, default=2000, help='frequency of showing training results on screen')
|
||||
parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
|
||||
parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
|
||||
parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
|
||||
# for training
|
||||
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
|
||||
parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
|
||||
parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate. This is NOT the total #epochs. Totla #epochs is niter + niter_decay')
|
||||
parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero')
|
||||
parser.add_argument('--optimizer', type=str, default='adam')
|
||||
parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
|
||||
parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
|
||||
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
|
||||
parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iterations.')
|
||||
# for discriminators
|
||||
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
|
||||
parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale|image)')
|
||||
parser.add_argument('--no_TTUR', action='store_true', help='Use TTUR training scheme')
|
||||
parser.add_argument('--real_reference_probability', type=float, default=0.0, help='self-supervised training probability')
|
||||
parser.add_argument('--hard_reference_probability', type=float, default=0.0, help='hard reference training probability')
|
||||
# training loss weights
|
||||
parser.add_argument('--weight_warp_self', type=float, default=0.0, help='push warp self to ref')
|
||||
parser.add_argument('--weight_warp_cycle', type=float, default=0.0, help='push warp cycle to ref')
|
||||
parser.add_argument('--weight_novgg_featpair', type=float, default=10.0, help='in no vgg setting, use pair feat loss in domain adaptation')
|
||||
parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)')
|
||||
parser.add_argument('--weight_gan', type=float, default=10.0, help='weight of all loss in stage1')
|
||||
parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
|
||||
parser.add_argument('--weight_ganFeat', type=float, default=10.0, help='weight for feature matching loss')
|
||||
parser.add_argument('--which_perceptual', type=str, default='4_2', help='relu5_2 or relu4_2')
|
||||
parser.add_argument('--weight_perceptual', type=float, default=0.001)
|
||||
parser.add_argument('--weight_vgg', type=float, default=10.0, help='weight for vgg loss')
|
||||
parser.add_argument('--weight_contextual', type=float, default=1.0, help='ctx loss weight')
|
||||
parser.add_argument('--weight_fm_ratio', type=float, default=1.0, help='vgg fm loss weight comp with ctx loss')
|
||||
self.isTrain = True
|
||||
return parser
|
|
@ -0,0 +1,10 @@
|
|||
torch==1.7.0
|
||||
torchvision
|
||||
matplotlib
|
||||
pillow
|
||||
imageio
|
||||
numpy
|
||||
pandas
|
||||
scipy
|
||||
scikit-image
|
||||
opencv-python
|
Двоичный файл не отображается.
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
import os
|
||||
import imageio
|
||||
import numpy as np
|
||||
import data
|
||||
from util.util import mkdir
|
||||
from options.test_options import TestOptions
|
||||
from models.pix2pix_model import Pix2PixModel
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt = TestOptions().parse()
|
||||
dataloader = data.create_dataloader(opt)
|
||||
model = Pix2PixModel(opt)
|
||||
if len(opt.gpu_ids) > 1:
|
||||
model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
|
||||
else:
|
||||
model.to(opt.gpu_ids[0])
|
||||
model.eval()
|
||||
save_root = os.path.join(opt.checkpoints_dir, opt.name, 'test')
|
||||
mkdir(save_root)
|
||||
for i, data_i in enumerate(dataloader):
|
||||
print('{} / {}'.format(i, len(dataloader)))
|
||||
if i * opt.batchSize >= opt.how_many:
|
||||
break
|
||||
imgs_num = data_i['label'].shape[0]
|
||||
out = model(data_i, mode='inference')
|
||||
if opt.save_per_img:
|
||||
try:
|
||||
for it in range(imgs_num):
|
||||
save_name = os.path.join(save_root, '%08d_%04d.png' % (i, it))
|
||||
save_image(out['fake_image'][it:it+1], save_name, padding=0, normalize=True)
|
||||
except OSError as err:
|
||||
print(err)
|
||||
else:
|
||||
label = data_i['label'][:,:3,:,:]
|
||||
imgs = torch.cat((label.cpu(), data_i['ref'].cpu(), out['fake_image'].data.cpu()), 0)
|
||||
try:
|
||||
save_name = os.path.join(save_root, '%08d.png' % i)
|
||||
save_image(imgs, save_name, nrow=imgs_num, padding=0, normalize=True)
|
||||
except OSError as err:
|
||||
print(err)
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
from options.train_options import TrainOptions
|
||||
import data
|
||||
from util.iter_counter import IterationCounter
|
||||
from util.util import print_current_errors
|
||||
from util.util import mkdir
|
||||
from trainers.pix2pix_trainer import Pix2PixTrainer
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# parse options
|
||||
opt = TrainOptions().parse()
|
||||
# print options to help debugging
|
||||
print(' '.join(sys.argv))
|
||||
dataloader = data.create_dataloader(opt)
|
||||
len_dataloader = len(dataloader)
|
||||
# create tool for counting iterations
|
||||
iter_counter = IterationCounter(opt, len(dataloader))
|
||||
# create trainer for our model
|
||||
trainer = Pix2PixTrainer(opt, resume_epoch=iter_counter.first_epoch)
|
||||
save_root = os.path.join('checkpoints', opt.name, 'train')
|
||||
mkdir(save_root)
|
||||
|
||||
for epoch in iter_counter.training_epochs():
|
||||
opt.epoch = epoch
|
||||
iter_counter.record_epoch_start(epoch)
|
||||
for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
|
||||
iter_counter.record_one_iteration()
|
||||
# Training
|
||||
# train generator
|
||||
if i % opt.D_steps_per_G == 0:
|
||||
trainer.run_generator_one_step(data_i)
|
||||
# train discriminator
|
||||
trainer.run_discriminator_one_step(data_i)
|
||||
if iter_counter.needs_printing():
|
||||
losses = trainer.get_latest_losses()
|
||||
try:
|
||||
print_current_errors(opt, epoch, iter_counter.epoch_iter,
|
||||
iter_counter.epoch_iter_num, losses, iter_counter.time_per_iter)
|
||||
except OSError as err:
|
||||
print(err)
|
||||
|
||||
if iter_counter.needs_displaying():
|
||||
imgs_num = data_i['label'].shape[0]
|
||||
|
||||
if opt.dataset_mode == 'deepfashionHD':
|
||||
label = data_i['label'][:,:3,:,:]
|
||||
|
||||
show_size = opt.display_winsize
|
||||
|
||||
imgs = torch.cat((label.cpu(), data_i['ref'].cpu(), \
|
||||
trainer.get_latest_generated().data.cpu(), \
|
||||
data_i['image'].cpu()), 0)
|
||||
|
||||
try:
|
||||
save_name = '%08d_%08d.png' % (epoch, iter_counter.total_steps_so_far)
|
||||
save_name = os.path.join(save_root, save_name)
|
||||
save_image(imgs, save_name, nrow=imgs_num, padding=0, normalize=True)
|
||||
except OSError as err:
|
||||
print(err)
|
||||
|
||||
if iter_counter.needs_saving():
|
||||
print('saving the latest model (epoch %d, total_steps %d)' %
|
||||
(epoch, iter_counter.total_steps_so_far))
|
||||
try:
|
||||
trainer.save('latest')
|
||||
iter_counter.record_current_iter()
|
||||
except OSError as err:
|
||||
import pdb; pdb.set_trace()
|
||||
print(err)
|
||||
|
||||
trainer.update_learning_rate(epoch)
|
||||
iter_counter.record_epoch_end()
|
||||
|
||||
if epoch % opt.save_epoch_freq == 0 or epoch == iter_counter.total_epochs:
|
||||
print('saving the model at the end of epoch %d, iters %d' %
|
||||
(epoch, iter_counter.total_steps_so_far))
|
||||
try:
|
||||
trainer.save('latest')
|
||||
trainer.save(epoch)
|
||||
except OSError as err:
|
||||
print(err)
|
||||
|
||||
print('Training was successfully finished.')
|
|
@ -0,0 +1,2 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import copy
|
||||
import sys
|
||||
import torch
|
||||
from models.pix2pix_model import Pix2PixModel
|
||||
try:
|
||||
from torch.cuda.amp import GradScaler
|
||||
except:
|
||||
# dummy GradScaler for PyTorch < 1.6
|
||||
class GradScaler:
|
||||
def __init__(self, enabled):
|
||||
pass
|
||||
def scale(self, loss):
|
||||
return loss
|
||||
def unscale_(self, optimizer):
|
||||
pass
|
||||
def step(self, optimizer):
|
||||
optimizer.step()
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
|
||||
class Pix2PixTrainer():
|
||||
"""
|
||||
Trainer creates the model and optimizers, and uses them to
|
||||
updates the weights of the network while reporting losses
|
||||
and the latest visuals to visualize the progress in training.
|
||||
"""
|
||||
|
||||
def __init__(self, opt, resume_epoch=0):
|
||||
self.opt = opt
|
||||
self.pix2pix_model = Pix2PixModel(opt)
|
||||
if len(opt.gpu_ids) > 1:
|
||||
self.pix2pix_model = torch.nn.DataParallel(self.pix2pix_model, device_ids=opt.gpu_ids)
|
||||
self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
|
||||
else:
|
||||
self.pix2pix_model.to(opt.gpu_ids[0])
|
||||
self.pix2pix_model_on_one_gpu = self.pix2pix_model
|
||||
self.generated = None
|
||||
if opt.isTrain:
|
||||
self.optimizer_G, self.optimizer_D = self.pix2pix_model_on_one_gpu.create_optimizers(opt)
|
||||
self.old_lr = opt.lr
|
||||
if opt.continue_train and opt.which_epoch == 'latest':
|
||||
try:
|
||||
load_path = os.path.join(opt.checkpoints_dir, opt.name, 'optimizer.pth')
|
||||
checkpoint = torch.load(load_path)
|
||||
self.optimizer_G.load_state_dict(checkpoint['G'])
|
||||
self.optimizer_D.load_state_dict(checkpoint['D'])
|
||||
except FileNotFoundError as err:
|
||||
print(err)
|
||||
print('Not find optimizer state dict: ' + load_path + '. Do not load optimizer!')
|
||||
|
||||
self.last_data, self.last_netCorr, self.last_netG, self.last_optimizer_G = None, None, None, None
|
||||
self.g_losses = {}
|
||||
self.d_losses = {}
|
||||
self.scaler = GradScaler(enabled=self.opt.amp)
|
||||
|
||||
def run_generator_one_step(self, data):
|
||||
self.optimizer_G.zero_grad()
|
||||
g_losses, out = self.pix2pix_model(data, mode='generator')
|
||||
g_loss = sum(g_losses.values()).mean()
|
||||
# g_loss.backward()
|
||||
self.scaler.scale(g_loss).backward()
|
||||
self.scaler.unscale_(self.optimizer_G)
|
||||
# self.optimizer_G.step()
|
||||
self.scaler.step(self.optimizer_G)
|
||||
self.scaler.update()
|
||||
self.g_losses = g_losses
|
||||
self.out = out
|
||||
|
||||
def run_discriminator_one_step(self, data):
|
||||
self.optimizer_D.zero_grad()
|
||||
GforD = {}
|
||||
GforD['fake_image'] = self.out['fake_image']
|
||||
GforD['adaptive_feature_seg'] = self.out['adaptive_feature_seg']
|
||||
GforD['adaptive_feature_img'] = self.out['adaptive_feature_img']
|
||||
d_losses = self.pix2pix_model(data, mode='discriminator', GforD=GforD)
|
||||
d_loss = sum(d_losses.values()).mean()
|
||||
# d_loss.backward()
|
||||
self.scaler.scale(d_loss).backward()
|
||||
self.scaler.unscale_(self.optimizer_D)
|
||||
# self.optimizer_D.step()
|
||||
self.scaler.step(self.optimizer_D)
|
||||
self.scaler.update()
|
||||
self.d_losses = d_losses
|
||||
|
||||
def get_latest_losses(self):
|
||||
return {**self.g_losses, **self.d_losses}
|
||||
|
||||
def get_latest_generated(self):
|
||||
return self.out['fake_image']
|
||||
|
||||
def update_learning_rate(self, epoch):
|
||||
self.update_learning_rate(epoch)
|
||||
|
||||
def save(self, epoch):
|
||||
self.pix2pix_model_on_one_gpu.save(epoch)
|
||||
if epoch == 'latest':
|
||||
torch.save({'G': self.optimizer_G.state_dict(), \
|
||||
'D': self.optimizer_D.state_dict(), \
|
||||
'lr': self.old_lr,}, \
|
||||
os.path.join(self.opt.checkpoints_dir, self.opt.name, 'optimizer.pth'))
|
||||
|
||||
def update_learning_rate(self, epoch):
|
||||
if epoch > self.opt.niter:
|
||||
lrd = self.opt.lr / self.opt.niter_decay
|
||||
new_lr = self.old_lr - lrd
|
||||
else:
|
||||
new_lr = self.old_lr
|
||||
if new_lr != self.old_lr:
|
||||
new_lr_G = new_lr
|
||||
new_lr_D = new_lr
|
||||
else:
|
||||
new_lr_G = self.old_lr
|
||||
new_lr_D = self.old_lr
|
||||
for param_group in self.optimizer_D.param_groups:
|
||||
param_group['lr'] = new_lr_D
|
||||
for param_group in self.optimizer_G.param_groups:
|
||||
param_group['lr'] = new_lr_G
|
||||
print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
|
||||
self.old_lr = new_lr
|
|
@ -0,0 +1,2 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
# Helper class that keeps track of training iterations
|
||||
class IterationCounter():
|
||||
def __init__(self, opt, dataset_size):
|
||||
self.opt = opt
|
||||
self.dataset_size = dataset_size
|
||||
self.batch_size = opt.batchSize
|
||||
self.first_epoch = 1
|
||||
self.total_epochs = opt.niter + opt.niter_decay
|
||||
# iter number within each epoch
|
||||
self.epoch_iter = 0
|
||||
self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt')
|
||||
if opt.isTrain and opt.continue_train:
|
||||
try:
|
||||
self.first_epoch, self.epoch_iter = np.loadtxt(self.iter_record_path, delimiter=',', dtype=int)
|
||||
print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter))
|
||||
except:
|
||||
print('Could not load iteration record at %s. Starting from beginning.' % self.iter_record_path)
|
||||
self.epoch_iter_num = dataset_size * self.batch_size
|
||||
self.total_steps_so_far = (self.first_epoch - 1) * self.epoch_iter_num + self.epoch_iter
|
||||
self.continue_train_flag = opt.continue_train
|
||||
|
||||
# return the iterator of epochs for the training
|
||||
def training_epochs(self):
|
||||
return range(self.first_epoch, self.total_epochs + 1)
|
||||
|
||||
def record_epoch_start(self, epoch):
|
||||
self.epoch_start_time = time.time()
|
||||
if not self.continue_train_flag:
|
||||
self.epoch_iter = 0
|
||||
else:
|
||||
self.continue_train_flag = False
|
||||
self.last_iter_time = time.time()
|
||||
self.current_epoch = epoch
|
||||
|
||||
def record_one_iteration(self):
|
||||
current_time = time.time()
|
||||
# the last remaining batch is dropped (see data/__init__.py),
|
||||
# so we can assume batch size is always opt.batchSize
|
||||
self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize
|
||||
self.last_iter_time = current_time
|
||||
self.total_steps_so_far += self.opt.batchSize
|
||||
self.epoch_iter += self.opt.batchSize
|
||||
|
||||
def record_epoch_end(self):
|
||||
current_time = time.time()
|
||||
self.time_per_epoch = current_time - self.epoch_start_time
|
||||
print('End of epoch %d / %d \t Time Taken: %d sec' %
|
||||
(self.current_epoch, self.total_epochs, self.time_per_epoch))
|
||||
if self.current_epoch % self.opt.save_epoch_freq == 0:
|
||||
np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0),
|
||||
delimiter=',', fmt='%d')
|
||||
print('Saved current iteration count at %s.' % self.iter_record_path)
|
||||
|
||||
def record_current_iter(self):
|
||||
np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter),
|
||||
delimiter=',', fmt='%d')
|
||||
print('Saved current iteration count at %s.' % self.iter_record_path)
|
||||
|
||||
def needs_saving(self):
|
||||
return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize
|
||||
|
||||
def needs_printing(self):
|
||||
return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize
|
||||
|
||||
def needs_displaying(self):
|
||||
return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize
|
|
@ -0,0 +1,106 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import re
|
||||
import argparse
|
||||
from argparse import Namespace
|
||||
import torch
|
||||
import numpy as np
|
||||
import importlib
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
def feature_normalize(feature_in, eps=1e-10):
|
||||
feature_in_norm = torch.norm(feature_in, 2, 1, keepdim=True) + eps
|
||||
feature_in_norm = torch.div(feature_in, feature_in_norm)
|
||||
return feature_in_norm
|
||||
|
||||
|
||||
def weighted_l1_loss(input, target, weights):
|
||||
out = torch.abs(input - target)
|
||||
out = out * weights.expand_as(out)
|
||||
loss = out.mean()
|
||||
return loss
|
||||
|
||||
|
||||
def mse_loss(input, target=0):
|
||||
return torch.mean((input - target)**2)
|
||||
|
||||
|
||||
def vgg_preprocess(tensor, vgg_normal_correct=False):
|
||||
if vgg_normal_correct:
|
||||
tensor = (tensor + 1) / 2
|
||||
# input is RGB tensor which ranges in [0,1]
|
||||
# output is BGR tensor which ranges in [0,255]
|
||||
tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1)
|
||||
# tensor_bgr = tensor[:, [2, 1, 0], ...]
|
||||
tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view(1, 3, 1, 1)
|
||||
tensor_rst = tensor_bgr_ml * 255
|
||||
return tensor_rst
|
||||
|
||||
|
||||
def mkdirs(paths):
|
||||
if isinstance(paths, list) and not isinstance(paths, str):
|
||||
for path in paths:
|
||||
mkdir(path)
|
||||
else:
|
||||
mkdir(paths)
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
def find_class_in_module(target_cls_name, module):
|
||||
target_cls_name = target_cls_name.replace('_', '').lower()
|
||||
clslib = importlib.import_module(module)
|
||||
cls = None
|
||||
for name, clsobj in clslib.__dict__.items():
|
||||
if name.lower() == target_cls_name:
|
||||
cls = clsobj
|
||||
if cls is None:
|
||||
print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
|
||||
exit(0)
|
||||
return cls
|
||||
|
||||
|
||||
def save_network(net, label, epoch, opt):
|
||||
save_filename = '%s_net_%s.pth' % (epoch, label)
|
||||
save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
|
||||
torch.save(net.cpu().state_dict(), save_path)
|
||||
if len(opt.gpu_ids) and torch.cuda.is_available():
|
||||
net.cuda()
|
||||
|
||||
|
||||
def load_network(net, label, epoch, opt):
|
||||
save_filename = '%s_net_%s.pth' % (epoch, label)
|
||||
save_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
||||
save_path = os.path.join(save_dir, save_filename)
|
||||
if not os.path.exists(save_path):
|
||||
print('not find model :' + save_path + ', do not load model!')
|
||||
return net
|
||||
weights = torch.load(save_path)
|
||||
try:
|
||||
net.load_state_dict(weights)
|
||||
except KeyError:
|
||||
print('key error, not load!')
|
||||
except RuntimeError as err:
|
||||
print(err)
|
||||
net.load_state_dict(weights, strict=False)
|
||||
print('loaded with strict = False')
|
||||
print('Load from ' + save_path)
|
||||
return net
|
||||
|
||||
|
||||
def print_current_errors(opt, epoch, i, num, errors, t):
|
||||
message = '(epoch: %d, iters: %d, finish: %.2f%%, time: %.3f) ' % (epoch, i, (i/num)*100.0, t)
|
||||
for k, v in errors.items():
|
||||
v = v.mean().float()
|
||||
message += '%s: %.3f ' % (k, v)
|
||||
print(message)
|
||||
log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
|
||||
with open(log_name, "a") as log_file:
|
||||
log_file.write('%s\n' % message)
|
Загрузка…
Ссылка в новой задаче