This commit is contained in:
xingranzh 2021-06-08 17:59:52 +08:00
Parent 6836c6b570
Commit cbc34ddd78
38 changed files: 2922 additions and 53 deletions

2
.gitattributes vendored Normal file
View file

@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto

138
.gitignore vendored Normal file
View file

@ -0,0 +1,138 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/

101
README.md
View file

@ -1,33 +1,78 @@
# CoCosNet v2: Full-Resolution Correspondence Learning for Image Translation (CVPR 2021, oral presentation)<br>
![teaser](imgs/teaser.png)
**CoCosNet v2: Full-Resolution Correspondence Learning for Image Translation**<br>
**CVPR 2021, oral presentation**<br>
[Xingran Zhou](http://xingranzh.github.io/), [Bo Zhang](https://www.microsoft.com/en-us/research/people/zhanbo/), [Ting Zhang](https://www.microsoft.com/en-us/research/people/tinzhan/), [Pan Zhang](https://panzhang0212.github.io/), [Jianmin Bao](https://jianminbao.github.io/), [Dong Chen](https://www.microsoft.com/en-us/research/people/doch/), [Zhongfei Zhang](https://www.cs.binghamton.edu/~zhongfei/), [Fang Wen](https://www.microsoft.com/en-us/research/people/fangwen/)<br>
Paper: https://arxiv.org/pdf/2012.02047.pdf<br>
Video: https://youtu.be/aBr1lOjm_FA<br>
Slides: https://github.com/xingranzh/CocosNet-v2/blob/master/slides/cocosnet_v2_slides.pdf<br>
Abstract: *We present the full-resolution correspondence learning for cross-domain images, which aids image translation. We adopt a hierarchical strategy that uses the correspondence from coarse level to guide the fine levels. At each hierarchy, the correspondence can be efficiently computed via PatchMatch that iteratively leverages the matchings from the neighborhood. Within each PatchMatch iteration, the ConvGRU module is employed to refine the current correspondence considering not only the matchings of larger context but also the historic estimates. The proposed CoCosNet v2, a GRU-assisted PatchMatch approach, is fully differentiable and highly efficient. When jointly trained with image translation, full-resolution semantic correspondence can be established in an unsupervised manner, which in turn facilitates the exemplar-based image translation. Experiments on diverse translation tasks show that CoCosNet v2 performs considerably better than state-of-the-art literature on producing high-resolution images.*<br>
## Installation
First, please install the dependencies for the experiments:
```bash
pip install -r requirements.txt
```
We recommend installing PyTorch 1.6.0 or later, since we rely on [automatic mixed precision](https://pytorch.org/docs/stable/amp.html) for acceleration (we used PyTorch 1.7.0 in our experiments).<br>
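As a quick sanity check, the snippet below (a minimal sketch, not part of this repository) verifies that your PyTorch build is recent enough and that the `torch.cuda.amp` API, which the `--amp` flag presumably relies on, is available:

```python
# check_env.py -- hypothetical helper, not part of the official codebase
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available :", torch.cuda.is_available())

# automatic mixed precision lives in torch.cuda.amp since PyTorch 1.6;
# this import fails on older versions
from torch.cuda.amp import GradScaler, autocast

if torch.cuda.is_available():
    with autocast():
        y = torch.randn(1, 3, 8, 8, device="cuda").sum()
    print("AMP autocast OK, dummy output:", float(y))
```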
## Prepare the dataset
First, download the DeepFashion dataset (high-resolution version) from [this link](https://drive.google.com/file/d/1bByKH1ciLXY70Bp8le_AVnjk-Hd4pe_i/view?usp=sharing). Note that the file name is `img_highres.zip`. Unzip the file and rename the extracted folder to `img`.<br>
If a password is required, please follow [this link](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html) to request access to the dataset.<br>
We use [OpenPose](https://github.com/Hzzone/pytorch-openpose) to estimate the poses for DeepFashionHD. The keypoint detection results used in our experiments are available at [this link](https://drive.google.com/file/d/1wxrqyb67Xu_IPyZzftLgBPHDTKGQP7Pk/view?usp=sharing); download and unzip the results file.<br>
Since the original resolution of DeepFashionHD is 750x1101, we use a Python script to resize the images to 512x512. The script is `data/preprocess.py` (a rough sketch of this resize step is shown after the directory tree below). Note that you need to download our train-val split lists `train.txt` and `val.txt` from [this link](https://drive.google.com/drive/folders/15NBujOTLnO_cRoAufWPqtOWKIinCKi0z?usp=sharing) for this step.<br>
Finally, create the root folder `DeepfashionHD` and move the folders `img` and `pose` into it. The directory structure should now look like this:<br>
```
DeepfashionHD
└─── img
│ │
│ └─── MEN
│ │ │ ...
│ │
│ └─── WOMEN
│ │ ...
└─── pose
│ │
│ └─── MEN
│ │ │ ...
│ │
│ └─── WOMEN
│ │ ...
```
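For reference, a minimal sketch of the kind of resize step `data/preprocess.py` performs might look like the following; the folder layout and in-place resize policy here are assumptions, so please use the official script for the real preprocessing:

```python
# resize_sketch.py -- illustrative only; data/preprocess.py is the official script
import os
from PIL import Image

IMG_ROOT = 'img'      # assumed: the unzipped (and renamed) DeepFashion image folder
SIZE = (512, 512)     # target resolution used throughout the repo

for root, _, files in os.walk(IMG_ROOT):
    for fname in files:
        if not fname.lower().endswith('.jpg'):
            continue
        path = os.path.join(root, fname)
        # original images are 750x1101; resize in place to 512x512
        Image.open(path).convert('RGB').resize(SIZE, Image.BICUBIC).save(path)
```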
## Inference Using Pretrained Model
Download the pretrained model from [this link](https://drive.google.com/file/d/1ehkrKlf5s1gfpDNXO6AC9SIZMtqs5L3N/view?usp=sharing).<br>
Move the model files into the folder `checkpoints/deepfashionHD`, then run the following command.
```bash
python test.py --name deepfashionHD --dataset_mode deepfashionHD --dataroot dataset/deepfashionHD --PONO --PONO_C --no_flip --batchSize 8 --gpu_ids 0 --netCorr NoVGGHPM --nThreads 16 --nef 32 --amp --display_winsize 512 --iteration_count 5 --load_size 512 --crop_size 512
```
The inference results are saved in the folder `checkpoints/deepfashionHD/test`.<br>
## Training from scratch
Make sure you have prepared the DeepFashionHD dataset as described above.<br>
Download the **pretrained VGG model** from [this link](https://drive.google.com/file/d/1D-z73DOt63BrPTgIxffN6Q4_L9qma9y8/view?usp=sharing) and move it to the `vgg/` folder. This model is used to compute the training loss.<br>
Download the train-val lists from [this link](https://drive.google.com/drive/folders/15NBujOTLnO_cRoAufWPqtOWKIinCKi0z?usp=sharing), and the retrieval pair lists from [this link](https://drive.google.com/drive/folders/1dJU8iq8kFiwq33nWtvj5Ql5rUh9fiXUi?usp=sharing). `train.txt` and `val.txt` are our train-val split lists; `deepfashion_ref.txt`, `deepfashion_ref_test.txt` and `deepfashion_self_pair.txt` are the pairing lists used in our experiments. Download them all and move them into the `data/` folder.<br>
Run the following command for training from scratch.
```bash
python train.py --name deepfashionHD --dataset_mode deepfashionHD --dataroot dataset/deepfashionHD --niter 100 --niter_decay 0 --real_reference_probability 0.0 --hard_reference_probability 0.0 --which_perceptual 4_2 --weight_perceptual 0.001 --PONO --PONO_C --vgg_normal_correct --weight_fm_ratio 1.0 --no_flip --video_like --batchSize 16 --gpu_ids 0,1,2,3,4,5,6,7 --netCorr NoVGGHPM --match_kernel 1 --featEnc_kernel 3 --display_freq 500 --print_freq 50 --save_latest_freq 2500 --save_epoch_freq 5 --nThreads 16 --weight_warp_self 500.0 --lr 0.0001 --nef 32 --amp --weight_warp_cycle 1.0 --display_winsize 512 --iteration_count 5 --temperature 0.01 --continue_train --load_size 550 --crop_size 512 --which_epoch 15
```
Note that the `--dataroot` parameter should point to your DeepFashionHD dataset root, e.g. `dataset/DeepFashionHD`.<br>
We train the network on 8 Tesla V100 GPUs (32 GB each). With fewer GPUs, you can set `batchSize` to 16, 8, or 4 and adjust `gpu_ids` accordingly.
## Citation
If you use this code for your research, please cite our papers.
```
@inproceedings{zhou2021full,
title={CoCosNet v2: Full-Resolution Correspondence Learning for Image Translation},
author={Zhou, Xingran and Zhang, Bo and Zhang, Ting and Zhang, Pan and Bao, Jianmin and Chen, Dong and Zhang, Zhongfei and Wen, Fang},
booktitle={CVPR},
year={2021}
}
```
## Acknowledgments
*This code borrows heavily from [CocosNet](https://github.com/microsoft/CoCosNet).
We also thank [SPADE](https://github.com/NVlabs/SPADE) and [RAFT](https://github.com/princeton-vl/RAFT).*
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
## Trademarks
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos is subject to those third parties' policies.
## License
The codes and the pretrained model in this repository are under the MIT license as specified by the LICENSE file.<br>
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

42
data/__init__.py Normal file
View file

@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset
def find_dataset_using_name(dataset_name):
dataset_filename = "data." + dataset_name + "_dataset"
datasetlib = importlib.import_module(dataset_filename)
dataset = None
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
for name, cls in datasetlib.__dict__.items():
if name.lower() == target_dataset_name.lower() \
and issubclass(cls, BaseDataset):
dataset = cls
if dataset is None:
raise ValueError("In %s.py, there should be a subclass of BaseDataset "
"with class name that matches %s in lowercase." %
(dataset_filename, target_dataset_name))
return dataset
def get_option_setter(dataset_name):
dataset_class = find_dataset_using_name(dataset_name)
return dataset_class.modify_commandline_options
def create_dataloader(opt):
dataset = find_dataset_using_name(opt.dataset_mode)
instance = dataset()
instance.initialize(opt)
print("Dataset [%s] of size %d was created" % (type(instance).__name__, len(instance)))
dataloader = torch.utils.data.DataLoader(
instance,
batch_size=opt.batchSize,
shuffle=(opt.phase=='train'),
num_workers=int(opt.nThreads),
drop_last=(opt.phase=='train')
)
return dataloader

135
data/base_dataset.py Normal file
View file

@ -0,0 +1,135 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
@staticmethod
def modify_commandline_options(parser, is_train):
return parser
def initialize(self, opt):
pass
def get_params(opt, size):
w, h = size
new_h = h
new_w = w
if opt.preprocess_mode == 'resize_and_crop':
new_h = new_w = opt.load_size
elif opt.preprocess_mode == 'scale_width_and_crop':
new_w = opt.load_size
new_h = opt.load_size * h // w
elif opt.preprocess_mode == 'scale_shortside_and_crop':
ss, ls = min(w, h), max(w, h) # shortside and longside
width_is_shorter = w == ss
ls = int(opt.load_size * ls / ss)
new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)
x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
flip = random.random() > 0.5
return {'crop_pos': (x, y), 'flip': flip}
def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
transform_list = []
if opt.dataset_mode == 'flickr' and method == Image.NEAREST:
transform_list.append(transforms.Lambda(lambda img: __add1(img)))
if 'resize' in opt.preprocess_mode:
osize = [opt.load_size, opt.load_size]
transform_list.append(transforms.Resize(osize, interpolation=method))
elif 'scale_width' in opt.preprocess_mode:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
elif 'scale_shortside' in opt.preprocess_mode:
transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method)))
if 'crop' in opt.preprocess_mode:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
if opt.preprocess_mode == 'none':
base = 32
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
if opt.preprocess_mode == 'fixed':
w = opt.crop_size
h = round(opt.crop_size / opt.aspect_ratio)
transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method)))
if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
if opt.isTrain and 'rotate' in params.keys():
transform_list.append(transforms.Lambda(lambda img: __rotate(img, params['rotate'], method)))
if toTensor:
transform_list += [transforms.ToTensor()]
if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
def __resize(img, w, h, method=Image.BICUBIC):
return img.resize((w, h), method)
def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if (h == oh) and (w == ow):
return img
return img.resize((w, h), method)
def __scale_width(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
if (ow == target_width):
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), method)
def __scale_shortside(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
ss, ls = min(ow, oh), max(ow, oh) # shortside and longside
width_is_shorter = ow == ss
if (ss == target_width):
return img
ls = int(target_width * ls / ss)
nw, nh = (ss, ls) if width_is_shorter else (ls, ss)
return img.resize((nw, nh), method)
def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
return img.crop((x1, y1, x1 + tw, y1 + th))
def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
def __rotate(img, deg, method=Image.BICUBIC):
return img.rotate(deg, resample=method)
def __add1(img):
return Image.fromarray(np.array(img) + 1)

data/deepfashionHD_dataset.py Normal file
View file

@ -0,0 +1,156 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import cv2
import torch
import numpy as np
import math
import random
from PIL import Image
from data.pix2pix_dataset import Pix2pixDataset
from data.base_dataset import get_params, get_transform
class DeepFashionHDDataset(Pix2pixDataset):
@staticmethod
def modify_commandline_options(parser, is_train):
parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
parser.set_defaults(preprocess_mode='resize_and_crop')
parser.set_defaults(no_pairing_check=True)
parser.set_defaults(load_size=550)
parser.set_defaults(crop_size=512)
parser.set_defaults(label_nc=20)
parser.set_defaults(contain_dontcare_label=False)
parser.set_defaults(cache_filelist_read=False)
parser.set_defaults(cache_filelist_write=False)
return parser
def get_paths(self, opt):
root = opt.dataroot
if opt.phase == 'train':
fd = open(os.path.join('./data/train.txt'))
lines = fd.readlines()
fd.close()
elif opt.phase == 'test':
fd = open(os.path.join('./data/val.txt'))
lines = fd.readlines()
fd.close()
image_paths = []
label_paths = []
for i in range(len(lines)):
name = lines[i].strip()
image_paths.append(name)
label_path = name.replace('img', 'pose').replace('.jpg', '_{}.txt')
label_paths.append(os.path.join(label_path))
return label_paths, image_paths
def get_ref_video_like(self, opt):
pair_path = './data/deepfashion_self_pair.txt'
with open(pair_path) as fd:
self_pair = fd.readlines()
self_pair = [it.strip() for it in self_pair]
self_pair_dict = {}
for it in self_pair:
items = it.split(',')
self_pair_dict[items[0]] = items[1:]
ref_path = './data/deepfashion_ref_test.txt' if opt.phase == 'test' else './data/deepfashion_ref.txt'
with open(ref_path) as fd:
ref = fd.readlines()
ref = [it.strip() for it in ref]
ref_dict = {}
for i in range(len(ref)):
items = ref[i].strip().split(',')
if items[0] in self_pair_dict.keys():
ref_dict[items[0]] = [it for it in self_pair_dict[items[0]]]
else:
ref_dict[items[0]] = [items[-1]]
train_test_folder = ('', '')
return ref_dict, train_test_folder
def get_ref_vgg(self, opt):
extra = ''
if opt.phase == 'test':
extra = '_test'
with open('./data/deepfashion_ref{}.txt'.format(extra)) as fd:
lines = fd.readlines()
ref_dict = {}
for i in range(len(lines)):
items = lines[i].strip().split(',')
key = items[0]
if opt.phase == 'test':
val = [it for it in items[1:]]
else:
val = [items[-1]]
ref_dict[key] = val
train_test_folder = ('', '')
return ref_dict, train_test_folder
def get_ref(self, opt):
if opt.video_like:
return self.get_ref_video_like(opt)
else:
return self.get_ref_vgg(opt)
def get_label_tensor(self, path):
candidate = np.loadtxt(path.format('candidate'))
subset = np.loadtxt(path.format('subset'))
stickwidth = 20
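# OpenPose skeleton definition: limbSeq gives 1-indexed keypoint pairs that form the limbs, colors is the standard OpenPose 18-color palette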
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
[1, 16], [16, 18], [3, 17], [6, 18]]
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
canvas = np.zeros((1024, 1024, 3), dtype=np.uint8)
cycle_radius = 20
for i in range(18):
index = int(subset[i])
if index == -1:
continue
x, y = candidate[index][0:2]
cv2.circle(canvas, (int(x), int(y)), cycle_radius, colors[i], thickness=-1)
joints = []
for i in range(17):
index = subset[np.array(limbSeq[i]) - 1]
cur_canvas = canvas.copy()
if -1 in index:
joints.append(np.zeros_like(cur_canvas[:, :, 0]))
continue
Y = candidate[index.astype(int), 0]
X = candidate[index.astype(int), 1]
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
joint = np.zeros_like(cur_canvas[:, :, 0])
cv2.fillConvexPoly(joint, polygon, 255)
joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
joints.append(joint)
pose = Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)).resize((self.opt.load_size, self.opt.load_size), resample=Image.NEAREST)
params = get_params(self.opt, pose.size)
transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
transform_img = get_transform(self.opt, params, method=Image.BILINEAR, normalize=False)
tensors_dist = 0
e = 1
for i in range(len(joints)):
im_dist = cv2.distanceTransform(255-joints[i], cv2.DIST_L1, 3)
im_dist = np.clip((im_dist/3), 0, 255).astype(np.uint8)
tensor_dist = transform_img(Image.fromarray(im_dist))
tensors_dist = tensor_dist if e == 1 else torch.cat([tensors_dist, tensor_dist])
e += 1
tensor_pose = transform_label(pose)
label_tensor = torch.cat((tensor_pose, tensors_dist), dim=0)
return label_tensor, params
def imgpath_to_labelpath(self, path):
label_path = path.replace('/img/', '/pose/').replace('.jpg', '_{}.txt')
return label_path
def labelpath_to_imgpath(self, path):
img_path = path.replace('/pose/', '/img/').replace('_{}.txt', '.jpg')
return img_path

127
data/pix2pix_dataset.py Normal file
View file

@ -0,0 +1,127 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import random
from PIL import Image
from data.base_dataset import BaseDataset, get_params, get_transform
class Pix2pixDataset(BaseDataset):
@staticmethod
def modify_commandline_options(parser, is_train):
parser.add_argument('--no_pairing_check', action='store_true', help='If specified, skip sanity check of correct label-image file pairing')
return parser
def initialize(self, opt):
self.opt = opt
label_paths, image_paths = self.get_paths(opt)
label_paths = label_paths[:opt.max_dataset_size]
image_paths = image_paths[:opt.max_dataset_size]
if not opt.no_pairing_check:
for path1, path2 in zip(label_paths, image_paths):
assert self.paths_match(path1, path2), \
"The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this." % (path1, path2)
self.label_paths = label_paths
self.image_paths = image_paths
size = len(self.label_paths)
self.dataset_size = size
self.real_reference_probability = 1 if opt.phase == 'test' else opt.real_reference_probability
self.hard_reference_probability = 0 if opt.phase == 'test' else opt.hard_reference_probability
self.ref_dict, self.train_test_folder = self.get_ref(opt)
def get_paths(self, opt):
label_paths = []
image_paths = []
assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)"
return label_paths, image_paths
def paths_match(self, path1, path2):
filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
return filename1_without_ext == filename2_without_ext
def get_label_tensor(self, path):
label = Image.open(path)
params1 = get_params(self.opt, label.size)
transform_label = get_transform(self.opt, params1, method=Image.NEAREST, normalize=False)
label_tensor = transform_label(label) * 255.0
label_tensor[label_tensor == 255] = self.opt.label_nc
# 'unknown' is opt.label_nc
return label_tensor, params1
def __getitem__(self, index):
# label Image
label_path = self.label_paths[index]
label_path = os.path.join(self.opt.dataroot, label_path)
label_tensor, params1 = self.get_label_tensor(label_path)
# input image (real images)
image_path = self.image_paths[index]
image_path = os.path.join(self.opt.dataroot, image_path)
image = Image.open(image_path).convert('RGB')
transform_image = get_transform(self.opt, params1)
image_tensor = transform_image(image)
ref_tensor = 0
label_ref_tensor = 0
random_p = random.random()
if random_p < self.real_reference_probability or self.opt.phase == 'test':
key = image_path.split('deepfashionHD/')[-1]
val = self.ref_dict[key]
if random_p < self.hard_reference_probability:
#hard reference
path_ref = val[1]
else:
#easy reference
path_ref = val[0]
if self.opt.dataset_mode == 'deepfashionHD':
path_ref = os.path.join(self.opt.dataroot, path_ref)
else:
path_ref = os.path.dirname(image_path).replace(self.train_test_folder[1], self.train_test_folder[0]) + '/' + path_ref
image_ref = Image.open(path_ref).convert('RGB')
if self.opt.dataset_mode != 'deepfashionHD':
path_ref_label = path_ref.replace('.jpg', '.png')
path_ref_label = self.imgpath_to_labelpath(path_ref_label)
else:
path_ref_label = self.imgpath_to_labelpath(path_ref)
label_ref_tensor, params = self.get_label_tensor(path_ref_label)
transform_image = get_transform(self.opt, params)
ref_tensor = transform_image(image_ref)
self_ref_flag = 0.0
else:
pair = False
if self.opt.dataset_mode == 'deepfashionHD' and self.opt.video_like:
key = image_path.replace('\\', '/').split('deepfashionHD/')[-1]
val = self.ref_dict[key]
ref_name = val[0]
key_name = key
path_ref = os.path.join(self.opt.dataroot, ref_name)
image_ref = Image.open(path_ref).convert('RGB')
label_ref_path = self.imgpath_to_labelpath(path_ref)
label_ref_tensor, params = self.get_label_tensor(label_ref_path)
transform_image = get_transform(self.opt, params)
ref_tensor = transform_image(image_ref)
pair = True
if not pair:
label_ref_tensor, params = self.get_label_tensor(label_path)
transform_image = get_transform(self.opt, params)
ref_tensor = transform_image(image)
self_ref_flag = 1.0
input_dict = {'label': label_tensor,
'image': image_tensor,
'path': image_path,
'self_ref': self_ref_flag,
'ref': ref_tensor,
'label_ref': label_ref_tensor
}
return input_dict
def __len__(self):
return self.dataset_size
def get_ref(self, opt):
pass
def imgpath_to_labelpath(self, path):
return path

Binary data
imgs/teaser.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 MiB

Binary data
imgs/teaser_1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 344 KiB

Binary data
imgs/teaser_2.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 MiB

38
models/__init__.py Normal file
View file

@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import importlib
def find_model_using_name(model_name):
# Given the option --model [modelname],
# the file "models/modelname_model.py"
# will be imported.
model_filename = "models." + model_name + "_model"
modellib = importlib.import_module(model_filename)
# In the file, the class called ModelNameModel() will
# be instantiated. It has to be a subclass of torch.nn.Module,
# and it is case-insensitive.
model = None
target_model_name = model_name.replace('_', '') + 'model'
for name, cls in modellib.__dict__.items():
if name.lower() == target_model_name.lower() \
and issubclass(cls, torch.nn.Module):
model = cls
if model is None:
print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." % (model_filename, target_model_name))
exit(0)
return model
def get_option_setter(model_name):
model_class = find_model_using_name(model_name)
return model_class.modify_commandline_options
def create_model(opt):
model = find_model_using_name(opt.model)
instance = model(opt)
print("model [%s] was created" % (type(instance).__name__))
return instance

models/networks/ContextualLoss.py Normal file
View file

@ -0,0 +1,71 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from util.util import feature_normalize, mse_loss
class ContextualLoss_forward(nn.Module):
'''
input is Al, Bl, channel = 1, range ~ [0, 255]
'''
def __init__(self, opt):
super(ContextualLoss_forward, self).__init__()
self.opt = opt
return None
def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
'''
X_features & Y_features are feature vectors or 2D feature maps
h: bandwidth
return the per-sample loss
'''
batch_size = X_features.shape[0]
feature_depth = X_features.shape[1]
feature_size = X_features.shape[2]
# to normalized feature vectors
if feature_centering:
if self.opt.PONO:
X_features = X_features - Y_features.mean(dim=1).unsqueeze(dim=1)
Y_features = Y_features - Y_features.mean(dim=1).unsqueeze(dim=1)
else:
X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(dim=-1)
Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(dim=-1)
# X_features = X_features - Y_features.mean(dim=1).unsqueeze(dim=1)
# Y_features = Y_features - Y_features.mean(dim=1).unsqueeze(dim=1)
X_features = feature_normalize(X_features).view(batch_size, feature_depth, -1) # batch_size * feature_depth * feature_size * feature_size
Y_features = feature_normalize(Y_features).view(batch_size, feature_depth, -1) # batch_size * feature_depth * feature_size * feature_size
# X_features = F.unfold(
# X_features, kernel_size=self.opt.match_kernel, stride=1, padding=int(self.opt.match_kernel // 2)) # batch_size * feature_depth_new * feature_size^2
# Y_features = F.unfold(
# Y_features, kernel_size=self.opt.match_kernel, stride=1, padding=int(self.opt.match_kernel // 2)) # batch_size * feature_depth_new * feature_size^2
# cosine distance = 1 - similarity
X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
# normalized distance: dij_bar
# d_norm = d
d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-3) # batch_size * feature_size^2 * feature_size^2
# pairwise affinity
w = torch.exp((1 - d_norm) / h)
A_ij = w / torch.sum(w, dim=-1, keepdim=True)
# contextual loss per sample
CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1)
loss = -torch.log(CX)
# contextual loss per batch
# loss = torch.mean(loss)
return loss

models/networks/__init__.py Normal file
View file

@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
from models.networks.base_network import BaseNetwork
from models.networks.loss import *
from models.networks.discriminator import *
from models.networks.generator import *
from models.networks.ContextualLoss import *
from models.networks.correspondence import *
from models.networks.ops import *
import util.util as util
def find_network_using_name(target_network_name, filename, add=True):
target_class_name = target_network_name + filename if add else target_network_name
module_name = 'models.networks.' + filename
network = util.find_class_in_module(target_class_name, module_name)
assert issubclass(network, BaseNetwork), \
"Class %s should be a subclass of BaseNetwork" % network
return network
def modify_commandline_options(parser, is_train):
opt, _ = parser.parse_known_args()
netG_cls = find_network_using_name(opt.netG, 'generator')
parser = netG_cls.modify_commandline_options(parser, is_train)
if is_train:
netD_cls = find_network_using_name(opt.netD, 'discriminator')
parser = netD_cls.modify_commandline_options(parser, is_train)
return parser
def create_network(cls, opt):
net = cls(opt)
net.print_network()
if len(opt.gpu_ids) > 0:
assert(torch.cuda.is_available())
net.cuda()
net.init_weights(opt.init_type, opt.init_variance)
return net
def define_G(opt):
netG_cls = find_network_using_name(opt.netG, 'generator')
return create_network(netG_cls, opt)
def define_D(opt):
netD_cls = find_network_using_name(opt.netD, 'discriminator')
return create_network(netD_cls, opt)
def define_Corr(opt):
netCoor_cls = find_network_using_name(opt.netCorr, 'correspondence')
return create_network(netCoor_cls, opt)

models/networks/architecture.py Normal file
View file

@ -0,0 +1,154 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm
from models.networks.normalization import SPADE
from util.util import vgg_preprocess
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
super(ResidualBlock, self).__init__()
self.relu = nn.PReLU()
self.model = nn.Sequential(
nn.ReflectionPad2d(padding),
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride),
nn.InstanceNorm2d(out_channels),
self.relu,
nn.ReflectionPad2d(padding),
nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride),
nn.InstanceNorm2d(out_channels),
)
def forward(self, x):
out = self.relu(x + self.model(x))
return out
class SPADEResnetBlock(nn.Module):
def __init__(self, fin, fout, opt, use_se=False, dilation=1):
super().__init__()
# Attributes
self.learned_shortcut = (fin != fout)
fmiddle = min(fin, fout)
self.opt = opt
self.pad_type = 'nozero'
self.use_se = use_se
# create conv layers
if self.pad_type != 'zero':
self.pad = nn.ReflectionPad2d(dilation)
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=0, dilation=dilation)
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=0, dilation=dilation)
else:
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
if self.learned_shortcut:
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
# apply spectral norm if specified
if 'spectral' in opt.norm_G:
self.conv_0 = spectral_norm(self.conv_0)
self.conv_1 = spectral_norm(self.conv_1)
if self.learned_shortcut:
self.conv_s = spectral_norm(self.conv_s)
# define normalization layers
spade_config_str = opt.norm_G.replace('spectral', '')
if 'spade_ic' in opt:
ic = opt.spade_ic
else:
ic = 4*3+opt.label_nc
self.norm_0 = SPADE(spade_config_str, fin, ic, PONO=opt.PONO)
self.norm_1 = SPADE(spade_config_str, fmiddle, ic, PONO=opt.PONO)
if self.learned_shortcut:
self.norm_s = SPADE(spade_config_str, fin, ic, PONO=opt.PONO)
def forward(self, x, seg1):
x_s = self.shortcut(x, seg1)
if self.pad_type != 'zero':
dx = self.conv_0(self.pad(self.actvn(self.norm_0(x, seg1))))
dx = self.conv_1(self.pad(self.actvn(self.norm_1(dx, seg1))))
else:
dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
out = x_s + dx
return out
def shortcut(self, x, seg1):
if self.learned_shortcut:
x_s = self.conv_s(self.norm_s(x, seg1))
else:
x_s = x
return x_s
def actvn(self, x):
return F.leaky_relu(x, 2e-1)
class VGG19_feature_color_torchversion(nn.Module):
"""
NOTE: there is no need to pre-process the input
input tensor should range in [0,1]
"""
def __init__(self, pool='max', vgg_normal_correct=False, ic=3):
super(VGG19_feature_color_torchversion, self).__init__()
self.vgg_normal_correct = vgg_normal_correct
self.conv1_1 = nn.Conv2d(ic, 64, kernel_size=3, padding=1)
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
if pool == 'max':
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
elif pool == 'avg':
self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)
def forward(self, x, out_keys, preprocess=True):
'''
NOTE: input tensor should range in [0,1]
'''
out = {}
if preprocess:
x = vgg_preprocess(x, vgg_normal_correct=self.vgg_normal_correct)
out['r11'] = F.relu(self.conv1_1(x))
out['r12'] = F.relu(self.conv1_2(out['r11']))
out['p1'] = self.pool1(out['r12'])
out['r21'] = F.relu(self.conv2_1(out['p1']))
out['r22'] = F.relu(self.conv2_2(out['r21']))
out['p2'] = self.pool2(out['r22'])
out['r31'] = F.relu(self.conv3_1(out['p2']))
out['r32'] = F.relu(self.conv3_2(out['r31']))
out['r33'] = F.relu(self.conv3_3(out['r32']))
out['r34'] = F.relu(self.conv3_4(out['r33']))
out['p3'] = self.pool3(out['r34'])
out['r41'] = F.relu(self.conv4_1(out['p3']))
out['r42'] = F.relu(self.conv4_2(out['r41']))
out['r43'] = F.relu(self.conv4_3(out['r42']))
out['r44'] = F.relu(self.conv4_4(out['r43']))
out['p4'] = self.pool4(out['r44'])
out['r51'] = F.relu(self.conv5_1(out['p4']))
out['r52'] = F.relu(self.conv5_2(out['r51']))
out['r53'] = F.relu(self.conv5_3(out['r52']))
out['r54'] = F.relu(self.conv5_4(out['r53']))
out['p5'] = self.pool5(out['r54'])
return [out[key] for key in out_keys]

models/networks/base_network.py Normal file
View file

@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.nn as nn
from torch.nn import init
class BaseNetwork(nn.Module):
def __init__(self):
super(BaseNetwork, self).__init__()
@staticmethod
def modify_commandline_options(parser, is_train):
return parser
def print_network(self):
if isinstance(self, list):
self = self[0]
num_params = 0
for param in self.parameters():
num_params += param.numel()
print('Network [%s] was created. Total number of parameters: %.1f million. '
'To see the architecture, do print(network).'
% (type(self).__name__, num_params / 1000000))
def init_weights(self, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if classname.find('BatchNorm2d') != -1:
if hasattr(m, 'weight') and m.weight is not None:
init.normal_(m.weight.data, 1.0, gain)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'xavier_uniform':
init.xavier_uniform_(m.weight.data, gain=1.0)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
elif init_type == 'none': # uses pytorch's default init method
m.reset_parameters()
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
self.apply(init_func)
# propagate to children
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights(init_type, gain)

View file

@ -0,0 +1,83 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowHead(nn.Module):
def __init__(self, input_dim=32, hidden_dim=64):
super(FlowHead, self).__init__()
candidate_num = 16
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, 2*candidate_num, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
num = x.size()[1]
delta_offset_x, delta_offset_y = torch.split(x, [num//2, num//2], dim=1)
return delta_offset_x, delta_offset_y
class SepConvGRU(nn.Module):
def __init__(self, hidden_dim=32, input_dim=64):
super(SepConvGRU, self).__init__()
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
def forward(self, h, x):
# horizontal
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz1(hx))
r = torch.sigmoid(self.convr1(hx))
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
# vertical
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz2(hx))
r = torch.sigmoid(self.convr2(hx))
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class BasicMotionEncoder(nn.Module):
def __init__(self):
super(BasicMotionEncoder, self).__init__()
candidate_num = 16
self.convc1 = nn.Conv2d(candidate_num, 64, 1, padding=0)
self.convc2 = nn.Conv2d(64, 64, 3, padding=1)
self.convf1 = nn.Conv2d(2*candidate_num, 64, 7, padding=3)
self.convf2 = nn.Conv2d(64, 64, 3, padding=1)
self.conv = nn.Conv2d(64+64, 64-2*candidate_num, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
cor = F.relu(self.convc2(cor))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class BasicUpdateBlock(nn.Module):
def __init__(self):
super(BasicUpdateBlock, self).__init__()
self.encoder = BasicMotionEncoder()
self.gru = SepConvGRU()
self.flow_head = FlowHead()
def forward(self, net, corr, flow):
motion_features = self.encoder(flow, corr)
inp = motion_features
net = self.gru(net, inp)
delta_offset_x, delta_offset_y = self.flow_head(net)
return net, delta_offset_x, delta_offset_y

models/networks/correspondence.py Normal file
View file

@ -0,0 +1,245 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import util.util as util
from models.networks.base_network import BaseNetwork
from models.networks.architecture import ResidualBlock
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import SPADEResnetBlock
"""patch match"""
from models.networks.patch_match import PatchMatchGRU
from models.networks.ops import *
def match_kernel_and_pono_c(feature, match_kernel, PONO_C, eps=1e-10):
b, c, h, w = feature.size()
if match_kernel == 1:
feature = feature.view(b, c, -1)
else:
feature = F.unfold(feature, kernel_size=match_kernel, padding=int(match_kernel//2))
dim_mean = 1 if PONO_C else -1
feature = feature - feature.mean(dim=dim_mean, keepdim=True)
feature_norm = torch.norm(feature, 2, 1, keepdim=True) + eps
feature = torch.div(feature, feature_norm)
return feature.view(b, -1, h, w)
"""512x512"""
class AdaptiveFeatureGenerator(BaseNetwork):
@staticmethod
def modify_commandline_options(parser, is_train):
return parser
def __init__(self, opt):
super().__init__()
self.opt = opt
kw = opt.featEnc_kernel
pw = int((kw-1)//2)
nf = opt.nef
norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
self.layer1 = norm_layer(nn.Conv2d(opt.spade_ic, nf, 3, stride=1, padding=pw))
self.layer2 = nn.Sequential(
norm_layer(nn.Conv2d(nf * 1, nf * 2, 3, 1, 1)),
ResidualBlock(nf * 2, nf * 2),
)
self.layer3 = nn.Sequential(
norm_layer(nn.Conv2d(nf * 2, nf * 4, kw, stride=2, padding=pw)),
ResidualBlock(nf * 4, nf * 4),
)
self.layer4 = nn.Sequential(
norm_layer(nn.Conv2d(nf * 4, nf * 4, kw, stride=2, padding=pw)),
ResidualBlock(nf * 4, nf * 4),
)
self.layer5 = nn.Sequential(
norm_layer(nn.Conv2d(nf * 4, nf * 4, kw, stride=2, padding=pw)),
ResidualBlock(nf * 4, nf * 4),
)
self.head_0 = SPADEResnetBlock(nf * 4, nf * 4, opt)
self.G_middle_0 = SPADEResnetBlock(nf * 4, nf * 4, opt)
self.G_middle_1 = SPADEResnetBlock(nf * 4, nf * 2, opt)
self.G_middle_2 = SPADEResnetBlock(nf * 2, nf * 1, opt)
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
def forward(self, input, seg):
# 512
x1 = self.layer1(input)
# 512
x2 = self.layer2(self.actvn(x1))
# 256
x3 = self.layer3(self.actvn(x2))
# 128
x4 = self.layer4(self.actvn(x3))
# 64
x5 = self.layer5(self.actvn(x4))
# bottleneck
x6 = self.head_0(x5, seg)
# 128
x7 = self.G_middle_0(self.up(x6) + x4, seg)
# 256
x8 = self.G_middle_1(self.up(x7) + x3, seg)
# 512
x9 = self.G_middle_2(self.up(x8) + x2, seg)
return [x6, x7, x8, x9]
def actvn(self, x):
return F.leaky_relu(x, 2e-1)
class NoVGGHPMCorrespondence(BaseNetwork):
def __init__(self, opt):
self.opt = opt
super().__init__()
opt.spade_ic = opt.semantic_nc
self.adaptive_model_seg = AdaptiveFeatureGenerator(opt)
opt.spade_ic = 3 + opt.semantic_nc
self.adaptive_model_img = AdaptiveFeatureGenerator(opt)
del opt.spade_ic
self.batch_size = opt.batchSize
"""512x512"""
feature_channel = opt.nef
self.phi_0 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
self.phi_1 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
self.phi_2 = nn.Conv2d(in_channels=feature_channel*2, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
self.phi_3 = nn.Conv2d(in_channels=feature_channel, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
self.theta_0 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
self.theta_1 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
self.theta_2 = nn.Conv2d(in_channels=feature_channel*2, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
self.theta_3 = nn.Conv2d(in_channels=feature_channel, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
self.patch_match = PatchMatchGRU(opt)
"""512x512"""
def multi_scale_patch_match(self, f1, f2, ref, hierarchical_scale, pre=None, real_img=None):
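# hierarchical_scale 0: dense 64x64 correlation (full softmax attention) warps the 8x-downsampled reference;
# hierarchical_scale 1-3 (128/256/512): top-k matches from the previous level are upsampled into candidate offsets for GRU-assisted PatchMatch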
if hierarchical_scale == 0:
y_cycle = None
scale = 64
batch_size, channel, feature_height, feature_width = f1.size()
ref = F.avg_pool2d(ref, 8, stride=8)
ref = ref.view(batch_size, 3, scale * scale)
f1 = f1.view(batch_size, channel, scale * scale)
f2 = f2.view(batch_size, channel, scale * scale)
matmul_result = torch.matmul(f1.permute(0, 2, 1), f2)/self.opt.temperature
mat = F.softmax(matmul_result, dim=-1)
y = torch.matmul(mat, ref.permute(0, 2, 1))
if self.opt.phase == 'train' and self.opt.weight_warp_cycle > 0:
mat_cycle = F.softmax(matmul_result.transpose(1, 2), dim=-1)
y_cycle = torch.matmul(mat_cycle, y)
y_cycle = y_cycle.permute(0, 2, 1).view(batch_size, 3, scale, scale)
y = y.permute(0, 2, 1).view(batch_size, 3, scale, scale)
return mat, y, y_cycle
if hierarchical_scale == 1:
scale = 128
with torch.no_grad():
batch_size, channel, feature_height, feature_width = f1.size()
topk_num = 1
search_window = 4
centering = 1
dilation = 2
total_candidate_num = topk_num * (search_window ** 2)
topk_inds = torch.topk(pre, topk_num, dim=-1)[-1]
inds = topk_inds.permute(0, 2, 1).view(batch_size, topk_num, (scale//2), (scale//2)).float()
offset_x, offset_y = inds_to_offset(inds)
dx = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=1).expand(-1, search_window).contiguous().view(-1) - centering
dy = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=0).expand(search_window, -1).contiguous().view(-1) - centering
dx = dx.view(1, search_window ** 2, 1, 1) * dilation
dy = dy.view(1, search_window ** 2, 1, 1) * dilation
offset_x_up = F.interpolate((2 * offset_x + dx), scale_factor=2)
offset_y_up = F.interpolate((2 * offset_y + dy), scale_factor=2)
ref = F.avg_pool2d(ref, 4, stride=4)
ref = ref.view(batch_size, 3, scale * scale)
mat, y = self.patch_match(f1, f2, ref, offset_x_up, offset_y_up)
y = y.view(batch_size, 3, scale, scale)
return mat, y
if hierarchical_scale == 2:
scale = 256
with torch.no_grad():
batch_size, channel, feature_height, feature_width = f1.size()
topk_num = 1
search_window = 4
centering = 1
dilation = 2
total_candidate_num = topk_num * (search_window ** 2)
topk_inds = pre[:, :, :topk_num]
inds = topk_inds.permute(0, 2, 1).view(batch_size, topk_num, (scale//2), (scale//2)).float()
offset_x, offset_y = inds_to_offset(inds)
dx = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=1).expand(-1, search_window).contiguous().view(-1) - centering
dy = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=0).expand(search_window, -1).contiguous().view(-1) - centering
dx = dx.view(1, search_window ** 2, 1, 1) * dilation
dy = dy.view(1, search_window ** 2, 1, 1) * dilation
offset_x_up = F.interpolate((2 * offset_x + dx), scale_factor=2)
offset_y_up = F.interpolate((2 * offset_y + dy), scale_factor=2)
ref = F.avg_pool2d(ref, 2, stride=2)
ref = ref.view(batch_size, 3, scale * scale)
mat, y = self.patch_match(f1, f2, ref, offset_x_up, offset_y_up)
y = y.view(batch_size, 3, scale, scale)
return mat, y
if hierarchical_scale == 3:
scale = 512
with torch.no_grad():
batch_size, channel, feature_height, feature_width = f1.size()
topk_num = 1
search_window = 4
centering = 1
dilation = 2
total_candidate_num = topk_num * (search_window ** 2)
topk_inds = pre[:, :, :topk_num]
inds = topk_inds.permute(0, 2, 1).view(batch_size, topk_num, (scale//2), (scale//2)).float()
offset_x, offset_y = inds_to_offset(inds)
dx = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=1).expand(-1, search_window).contiguous().view(-1) - centering
dy = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=0).expand(search_window, -1).contiguous().view(-1) - centering
dx = dx.view(1, search_window ** 2, 1, 1) * dilation
dy = dy.view(1, search_window ** 2, 1, 1) * dilation
offset_x_up = F.interpolate((2 * offset_x + dx), scale_factor=2)
offset_y_up = F.interpolate((2 * offset_y + dy), scale_factor=2)
ref = ref.view(batch_size, 3, scale * scale)
mat, y = self.patch_match(f1, f2, ref, offset_x_up, offset_y_up)
y = y.view(batch_size, 3, scale, scale)
return mat, y
def forward(self, ref_img, real_img, seg_map, ref_seg_map):
corr_out = {}
seg_input = seg_map
adaptive_feature_seg = self.adaptive_model_seg(seg_input, seg_input)
ref_input = torch.cat((ref_img, ref_seg_map), dim=1)
adaptive_feature_img = self.adaptive_model_img(ref_input, ref_input)
for i in range(len(adaptive_feature_seg)):
adaptive_feature_seg[i] = util.feature_normalize(adaptive_feature_seg[i])
adaptive_feature_img[i] = util.feature_normalize(adaptive_feature_img[i])
if self.opt.isTrain and self.opt.weight_novgg_featpair > 0:
real_input = torch.cat((real_img, seg_map), dim=1)
adaptive_feature_img_pair = self.adaptive_model_img(real_input, real_input)
loss_novgg_featpair = 0
weights = [1.0, 1.0, 1.0, 1.0]
for i in range(len(adaptive_feature_img_pair)):
adaptive_feature_img_pair[i] = util.feature_normalize(adaptive_feature_img_pair[i])
loss_novgg_featpair += F.l1_loss(adaptive_feature_seg[i], adaptive_feature_img_pair[i]) * weights[i]
corr_out['loss_novgg_featpair'] = loss_novgg_featpair * self.opt.weight_novgg_featpair
cont_features = adaptive_feature_seg
ref_features = adaptive_feature_img
theta = []
phi = []
"""512x512"""
theta.append(match_kernel_and_pono_c(self.theta_0(cont_features[0]), self.opt.match_kernel, self.opt.PONO_C))
theta.append(match_kernel_and_pono_c(self.theta_1(cont_features[1]), self.opt.match_kernel, self.opt.PONO_C))
theta.append(match_kernel_and_pono_c(self.theta_2(cont_features[2]), self.opt.match_kernel, self.opt.PONO_C))
theta.append(match_kernel_and_pono_c(self.theta_3(cont_features[3]), self.opt.match_kernel, self.opt.PONO_C))
phi.append(match_kernel_and_pono_c(self.phi_0(ref_features[0]), self.opt.match_kernel, self.opt.PONO_C))
phi.append(match_kernel_and_pono_c(self.phi_1(ref_features[1]), self.opt.match_kernel, self.opt.PONO_C))
phi.append(match_kernel_and_pono_c(self.phi_2(ref_features[2]), self.opt.match_kernel, self.opt.PONO_C))
phi.append(match_kernel_and_pono_c(self.phi_3(ref_features[3]), self.opt.match_kernel, self.opt.PONO_C))
ref = ref_img
ys = []
m = None
for i in range(len(theta)):
if i == 0:
m, y, y_cycle = self.multi_scale_patch_match(theta[i], phi[i], ref, i, pre=m)
if y_cycle is not None:
corr_out['warp_cycle'] = y_cycle
else:
m, y = self.multi_scale_patch_match(theta[i], phi[i], ref, i, pre=m)
ys.append(y)
corr_out['warp_out'] = ys
return corr_out

models/networks/discriminator.py Normal file
View file

@ -0,0 +1,114 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
import util.util as util
class MultiscaleDiscriminator(BaseNetwork):
@staticmethod
def modify_commandline_options(parser, is_train):
parser.add_argument('--netD_subarch', type=str, default='n_layer',
help='architecture of each discriminator')
parser.add_argument('--num_D', type=int, default=2,
help='number of discriminators to be used in multiscale')
opt, _ = parser.parse_known_args()
# define properties of each discriminator of the multiscale discriminator
subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator', \
'models.networks.discriminator')
subnetD.modify_commandline_options(parser, is_train)
return parser
def __init__(self, opt):
super().__init__()
self.opt = opt
for i in range(opt.num_D):
subnetD = self.create_single_discriminator(opt)
self.add_module('discriminator_%d' % i, subnetD)
def create_single_discriminator(self, opt):
subarch = opt.netD_subarch
if subarch == 'n_layer':
netD = NLayerDiscriminator(opt)
else:
raise ValueError('unrecognized discriminator subarchitecture %s' % subarch)
return netD
def downsample(self, input):
return F.avg_pool2d(input, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
def forward(self, input):
result = []
get_intermediate_features = not self.opt.no_ganFeat_loss
for name, D in self.named_children():
out = D(input)
if not get_intermediate_features:
out = [out]
result.append(out)
input = self.downsample(input)
return result
class NLayerDiscriminator(BaseNetwork):
@staticmethod
def modify_commandline_options(parser, is_train):
parser.add_argument('--n_layers_D', type=int, default=4, help='# layers in each discriminator')
return parser
def __init__(self, opt):
super().__init__()
self.opt = opt
kw = 4
padw = int((kw - 1.0) / 2)
nf = opt.ndf
input_nc = self.compute_D_input_nc(opt)
norm_layer = get_nonspade_norm_layer(opt, opt.norm_D)
sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, False)]]
for n in range(1, opt.n_layers_D):
nf_prev = nf
nf = min(nf * 2, 512)
stride = 1 if n == opt.n_layers_D - 1 else 2
if n == opt.n_layers_D - 1:
dec = []
nc_dec = nf_prev
for _ in range(opt.n_layers_D - 1):
dec += [nn.Upsample(scale_factor=2),
norm_layer(nn.Conv2d(nc_dec, int(nc_dec//2), kernel_size=3, stride=1, padding=1)),
nn.LeakyReLU(0.2, False)]
nc_dec = int(nc_dec // 2)
dec += [nn.Conv2d(nc_dec, opt.semantic_nc, kernel_size=3, stride=1, padding=1)]
self.dec = nn.Sequential(*dec)
sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
nn.LeakyReLU(0.2, False)]]
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
for n in range(len(sequence)):
self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
def compute_D_input_nc(self, opt):
input_nc = opt.label_nc + opt.output_nc
if opt.contain_dontcare_label:
input_nc += 1
return input_nc
def forward(self, input):
results = [input]
seg = None
cam_logit = None
for name, submodel in self.named_children():
if 'model' not in name:
continue
x = results[-1]
intermediate_output = submodel(x)
results.append(intermediate_output)
get_intermediate_features = not self.opt.no_ganFeat_loss
if get_intermediate_features:
retu = results[1:]
else:
retu = results[-1]
return retu
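
For reference, a minimal usage sketch of the multiscale discriminator. The options namespace below is a hypothetical stand-in for the parsed options (values mirror the defaults defined in options/), and the path models/networks/discriminator.py is assumed from the module name used in modify_commandline_options:

import argparse
import torch
from models.networks.discriminator import MultiscaleDiscriminator   # assumed path

opt = argparse.Namespace(num_D=2, netD_subarch='n_layer', n_layers_D=4, ndf=64,
                         label_nc=182, output_nc=3, contain_dontcare_label=False,
                         semantic_nc=182, norm_D='spectralinstance', no_ganFeat_loss=False)
netD = MultiscaleDiscriminator(opt)
x = torch.randn(1, opt.label_nc + opt.output_nc, 256, 256)   # cat(semantics, image)
preds = netD(x)
# One entry per discriminator; each entry is a list of intermediate feature maps
# ending with the real/fake logit map (kept because no_ganFeat_loss is False).
print(len(preds), [p[-1].shape for p in preds])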


@@ -0,0 +1,61 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from models.networks.base_network import BaseNetwork
from models.networks.architecture import SPADEResnetBlock
class SPADEGenerator(BaseNetwork):
@staticmethod
def modify_commandline_options(parser, is_train):
parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
return parser
def __init__(self, opt):
super().__init__()
self.opt = opt
nf = opt.ngf
self.sw, self.sh = self.compute_latent_vector_size(opt)
ic = 4*3+opt.label_nc
self.fc = nn.Conv2d(ic, 8 * nf, 3, padding=1)
self.head_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
self.G_middle_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
self.G_middle_1 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
self.up_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)
final_nc = nf
self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
self.up = nn.Upsample(scale_factor=2)
def compute_latent_vector_size(self, opt):
num_up_layers = 5
sw = opt.crop_size // (2**num_up_layers)
sh = round(sw / opt.aspect_ratio)
return sw, sh
def forward(self, input, warp_out=None):
seg = torch.cat((F.interpolate(warp_out[0], size=(512, 512)), F.interpolate(warp_out[1], size=(512, 512)), F.interpolate(warp_out[2], size=(512, 512)), warp_out[3], input), dim=1)
x = F.interpolate(seg, size=(self.sh, self.sw))
x = self.fc(x)
x = self.head_0(x, seg)
x = self.up(x)
x = self.G_middle_0(x, seg)
x = self.G_middle_1(x, seg)
x = self.up(x)
x = self.up_0(x, seg)
x = self.up(x)
x = self.up_1(x, seg)
x = self.up(x)
x = self.up_2(x, seg)
x = self.up(x)
x = self.up_3(x, seg)
x = self.conv_img(F.leaky_relu(x, 2e-1))
x = torch.tanh(x)
return x

89
models/networks/loss.py Normal file

@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class GANLoss(nn.Module):
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
tensor=torch.FloatTensor, opt=None):
super(GANLoss, self).__init__()
self.real_label = target_real_label
self.fake_label = target_fake_label
self.real_label_tensor = None
self.fake_label_tensor = None
self.zero_tensor = None
self.Tensor = tensor
self.gan_mode = gan_mode
self.opt = opt
if gan_mode == 'ls':
pass
elif gan_mode == 'original':
pass
elif gan_mode == 'w':
pass
elif gan_mode == 'hinge':
pass
else:
raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
def get_target_tensor(self, input, target_is_real):
if target_is_real:
if self.real_label_tensor is None:
self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
self.real_label_tensor.requires_grad_(False)
return self.real_label_tensor.expand_as(input)
else:
if self.fake_label_tensor is None:
self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
self.fake_label_tensor.requires_grad_(False)
return self.fake_label_tensor.expand_as(input)
def get_zero_tensor(self, input):
if self.zero_tensor is None:
self.zero_tensor = self.Tensor(1).fill_(0)
self.zero_tensor.requires_grad_(False)
return self.zero_tensor.expand_as(input).type_as(input)
def loss(self, input, target_is_real, for_discriminator=True):
if self.gan_mode == 'original': # cross entropy loss
target_tensor = self.get_target_tensor(input, target_is_real)
loss = F.binary_cross_entropy_with_logits(input, target_tensor)
return loss
elif self.gan_mode == 'ls':
target_tensor = self.get_target_tensor(input, target_is_real)
return F.mse_loss(input, target_tensor)
elif self.gan_mode == 'hinge':
if for_discriminator:
if target_is_real:
minval = torch.min(input - 1, self.get_zero_tensor(input))
loss = -torch.mean(minval)
else:
minval = torch.min(-input - 1, self.get_zero_tensor(input))
loss = -torch.mean(minval)
else:
assert target_is_real, "The generator's hinge loss must be aiming for real"
loss = -torch.mean(input)
return loss
else:
# wgan
if target_is_real:
return -input.mean()
else:
return input.mean()
def __call__(self, input, target_is_real, for_discriminator=True):
if isinstance(input, list):
loss = 0
for pred_i in input:
if isinstance(pred_i, list):
pred_i = pred_i[-1]
loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
loss += new_loss
return loss / len(input)
else:
return self.loss(input, target_is_real, for_discriminator)
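
A minimal usage sketch for GANLoss in the default hinge mode; the nested lists below imitate the structure returned by the multiscale discriminator (one list per discriminator, each ending with the logit map):

import torch
from models.networks.loss import GANLoss

criterion = GANLoss('hinge', tensor=torch.FloatTensor)
pred_fake = [[torch.randn(2, 64, 64, 64), torch.randn(2, 1, 64, 64)],
             [torch.randn(2, 64, 32, 32), torch.randn(2, 1, 32, 32)]]
d_fake = criterion(pred_fake, target_is_real=False, for_discriminator=True)
g_gan  = criterion(pred_fake, target_is_real=True,  for_discriminator=False)
# Only the last tensor of each inner list (the logit map) enters the loss;
# the result is averaged over the two discriminators.
print(d_fake.item(), g_gan.item())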


@@ -0,0 +1,97 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm
def get_nonspade_norm_layer(opt, norm_type='instance'):
def get_out_channel(layer):
if hasattr(layer, 'out_channels'):
return getattr(layer, 'out_channels')
return layer.weight.size(0)
def add_norm_layer(layer):
nonlocal norm_type
if norm_type.startswith('spectral'):
layer = spectral_norm(layer)
subnorm_type = norm_type[len('spectral'):]
else:
subnorm_type = norm_type
if subnorm_type == 'none' or len(subnorm_type) == 0:
return layer
if getattr(layer, 'bias', None) is not None:
delattr(layer, 'bias')
layer.register_parameter('bias', None)
if subnorm_type == 'batch':
norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
elif subnorm_type == 'sync_batch':
norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
elif subnorm_type == 'instance':
norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
else:
raise ValueError('normalization layer %s is not recognized' % subnorm_type)
return nn.Sequential(layer, norm_layer)
return add_norm_layer
def PositionalNorm2d(x, epsilon=1e-8):
# x: B*C*W*H normalize in C dim
mean = x.mean(dim=1, keepdim=True)
std = x.var(dim=1, keepdim=True).add(epsilon).sqrt()
output = (x - mean) / std
return output
class SPADE(nn.Module):
def __init__(self, config_text, norm_nc, label_nc, PONO=False):
super().__init__()
assert config_text.startswith('spade')
parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
param_free_norm_type = str(parsed.group(1))
ks = int(parsed.group(2))
self.pad_type = 'nozero'
if PONO:
self.param_free_norm = PositionalNorm2d
elif param_free_norm_type == 'instance':
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
elif param_free_norm_type == 'syncbatch':
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=True)
elif param_free_norm_type == 'batch':
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=True)
else:
raise ValueError('%s is not a recognized param-free norm type in SPADE' % param_free_norm_type)
nhidden = 128
pw = ks // 2
if self.pad_type != 'zero':
self.mlp_shared = nn.Sequential(
nn.ReflectionPad2d(pw),
nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=0),
nn.ReLU()
)
self.pad = nn.ReflectionPad2d(pw)
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=0)
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=0)
else:
self.mlp_shared = nn.Sequential(
nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
nn.ReLU()
)
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
def forward(self, x, segmap):
normalized = self.param_free_norm(x)
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
actv = self.mlp_shared(segmap)
if self.pad_type != 'zero':
gamma = self.mlp_gamma(self.pad(actv))
beta = self.mlp_beta(self.pad(actv))
else:
gamma = self.mlp_gamma(actv)
beta = self.mlp_beta(actv)
out = normalized * (1 + gamma) + beta
return out
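
A small sketch of how get_nonspade_norm_layer composes spectral norm with a parameter-free norm layer; the opt argument is unused inside the helper, so None is passed here purely for illustration:

import torch
import torch.nn as nn
from models.networks.normalization import get_nonspade_norm_layer

norm_layer = get_nonspade_norm_layer(None, 'spectralinstance')
block = norm_layer(nn.Conv2d(64, 128, kernel_size=3, padding=1))
# block is Sequential(spectral_norm(Conv2d) with bias removed, InstanceNorm2d(128))
print(block(torch.randn(1, 64, 32, 32)).shape)   # torch.Size([1, 128, 32, 32])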

50
models/networks/ops.py Normal file

@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
def convert_1d_to_2d(index, base=64):
x = index // base
y = index % base
return x,y
def convert_2d_to_1d(x, y, base=64):
return x*base+y
def batch_meshgrid(shape, device):
batch_size, _, height, width = shape
x_range = torch.arange(0.0, width, device=device)
y_range = torch.arange(0.0, height, device=device)
x_coordinate, y_coordinate = torch.meshgrid(x_range, y_range)
x_coordinate = x_coordinate.expand(batch_size, -1, -1).unsqueeze(1)
y_coordinate = y_coordinate.expand(batch_size, -1, -1).unsqueeze(1)
return x_coordinate, y_coordinate
def inds_to_offset(inds):
"""
inds: b x number x h x w
"""
shape = inds.size()
device = inds.device
x_coordinate, y_coordinate = batch_meshgrid(shape, device)
batch_size, _, height, width = shape
x = inds // width
y = inds % width
return x - x_coordinate, y - y_coordinate
def offset_to_inds(offset_x, offset_y):
shape = offset_x.size()
device = offset_x.device
x_coordinate, y_coordinate = batch_meshgrid(shape, device)
h, w = offset_x.size()[2:]
x = torch.clamp(x_coordinate + offset_x, 0, h-1)
y = torch.clamp(y_coordinate + offset_y, 0, w-1)
return x * offset_x.size()[3] + y
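
A round-trip sketch on a tiny grid showing that inds_to_offset and offset_to_inds are inverses of each other (indices are flat positions in an h x w grid, offsets are relative to each pixel's own coordinates):

import torch
from models.networks.ops import inds_to_offset, offset_to_inds

h, w = 4, 4
inds = torch.full((1, 1, h, w), 5.0)          # every pixel points at flat index 5 (row 1, col 1)
offset_x, offset_y = inds_to_offset(inds)     # offsets relative to each pixel's own position
recovered = offset_to_inds(offset_x, offset_y)
print(torch.equal(recovered, inds))           # True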


@@ -0,0 +1,169 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from models.networks.convgru import BasicUpdateBlock
from models.networks.ops import *
"""patch match"""
class Evaluate(nn.Module):
def __init__(self, temperature):
super().__init__()
self.filter_size = 3
self.temperature = temperature
def forward(self, left_features, right_features, offset_x, offset_y):
device = left_features.get_device()
batch_size, num, height, width = offset_x.size()
channel = left_features.size()[1]
matching_inds = offset_to_inds(offset_x, offset_y)
matching_inds = matching_inds.view(batch_size, num, height * width).permute(0, 2, 1).long()
base_batch = torch.arange(batch_size).to(device).long() * (height * width)
base_batch = base_batch.view(-1, 1, 1)
matching_inds_add_base = matching_inds + base_batch
right_features_view = right_features
match_cost = []
# using A[:, idx]
for i in range(matching_inds_add_base.size()[-1]):
idx = matching_inds_add_base[:, :, i]
idx = idx.contiguous().view(-1)
right_features_select = right_features_view[:, idx]
right_features_select = right_features_select.view(channel, batch_size, -1).transpose(0, 1)
match_cost_i = torch.sum(left_features * right_features_select, dim=1, keepdim=True) / self.temperature
match_cost.append(match_cost_i)
match_cost = torch.cat(match_cost, dim=1).transpose(1, 2)
match_cost = F.softmax(match_cost, dim=-1)
match_cost_topk, match_cost_topk_indices = torch.topk(match_cost, num//self.filter_size, dim=-1)
matching_inds = torch.gather(matching_inds, -1, match_cost_topk_indices)
matching_inds = matching_inds.permute(0, 2, 1).view(batch_size, -1, height, width).float()
offset_x, offset_y = inds_to_offset(matching_inds)
corr = match_cost_topk.permute(0, 2, 1)
return offset_x, offset_y, corr
class PropagationFaster(nn.Module):
def __init__(self):
super().__init__()
def forward(self, offset_x, offset_y, propagation_type="horizontal"):
device = offset_x.get_device()
self.horizontal_zeros = torch.zeros((offset_x.size()[0], offset_x.size()[1], offset_x.size()[2], 1)).to(device)
self.vertical_zeros = torch.zeros((offset_x.size()[0], offset_x.size()[1], 1, offset_x.size()[3])).to(device)
if propagation_type == "horizontal":
offset_x = torch.cat((torch.cat((self.horizontal_zeros, offset_x[:, :, :, :-1]), dim=3),
offset_x,
torch.cat((offset_x[:, :, :, 1:], self.horizontal_zeros), dim=3)), dim=1)
offset_y = torch.cat((torch.cat((self.horizontal_zeros, offset_y[:, :, :, :-1]), dim=3),
offset_y,
torch.cat((offset_y[:, :, :, 1:], self.horizontal_zeros), dim=3)), dim=1)
else:
offset_x = torch.cat((torch.cat((self.vertical_zeros, offset_x[:, :, :-1, :]), dim=2),
offset_x,
torch.cat((offset_x[:, :, 1:, :], self.vertical_zeros), dim=2)), dim=1)
offset_y = torch.cat((torch.cat((self.vertical_zeros, offset_y[:, :, :-1, :]), dim=2),
offset_y,
torch.cat((offset_y[:, :, 1:, :], self.vertical_zeros), dim=2)), dim=1)
return offset_x, offset_y
class PatchMatchOnce(nn.Module):
def __init__(self, opt):
super().__init__()
self.propagation = PropagationFaster()
self.evaluate = Evaluate(opt.temperature)
def forward(self, left_features, right_features, offset_x, offset_y):
prob = random.random()
if prob < 0.5:
offset_x, offset_y = self.propagation(offset_x, offset_y, "horizontal")
offset_x, offset_y, _ = self.evaluate(left_features, right_features, offset_x, offset_y)
offset_x, offset_y = self.propagation(offset_x, offset_y, "vertical")
offset_x, offset_y, corr = self.evaluate(left_features, right_features, offset_x, offset_y)
else:
offset_x, offset_y = self.propagation(offset_x, offset_y, "vertical")
offset_x, offset_y, _ = self.evaluate(left_features, right_features, offset_x, offset_y)
offset_x, offset_y = self.propagation(offset_x, offset_y, "horizontal")
offset_x, offset_y, corr = self.evaluate(left_features, right_features, offset_x, offset_y)
return offset_x, offset_y, corr
class PatchMatchGRU(nn.Module):
def __init__(self, opt):
super().__init__()
self.patch_match_one_step = PatchMatchOnce(opt)
self.temperature = opt.temperature
self.iters = opt.iteration_count
input_dim = opt.nef
hidden_dim = 32
norm = nn.InstanceNorm2d(hidden_dim, affine=False)
relu = nn.ReLU(inplace=True)
"""
concat left and right features
"""
self.initial_layer = nn.Sequential(
nn.Conv2d(input_dim*2, hidden_dim, kernel_size=3, padding=1, stride=1),
norm,
relu,
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
norm,
relu,
)
self.refine_net = BasicUpdateBlock()
def forward(self, left_features, right_features, right_input, initial_offset_x, initial_offset_y):
device = left_features.get_device()
batch_size, channel, height, width = left_features.size()
num = initial_offset_x.size()[1]
initial_input = torch.cat((left_features, right_features), dim=1)
hidden = self.initial_layer(initial_input)
left_features = left_features.view(batch_size, -1, height * width)
right_features = right_features.view(batch_size, -1, height * width)
right_features_view = right_features.transpose(0, 1).contiguous().view(channel, -1)
with torch.no_grad():
offset_x, offset_y = initial_offset_x, initial_offset_y
for it in range(self.iters):
with torch.no_grad():
offset_x, offset_y, corr = self.patch_match_one_step(left_features, right_features_view, offset_x, offset_y)
"""GRU refinement"""
flow = torch.cat((offset_x, offset_y), dim=1)
corr = corr.view(batch_size, -1, height, width)
hidden, delta_offset_x, delta_offset_y = self.refine_net(hidden, corr, flow)
offset_x = offset_x + delta_offset_x
offset_y = offset_y + delta_offset_y
with torch.no_grad():
matching_inds = offset_to_inds(offset_x, offset_y)
matching_inds = matching_inds.view(batch_size, num, height * width).permute(0, 2, 1).long()
base_batch = torch.arange(batch_size).to(device).long() * (height * width)
base_batch = base_batch.view(-1, 1, 1)
matching_inds_plus_base = matching_inds + base_batch
match_cost = []
# using A[:, idx]
for i in range(matching_inds_plus_base.size()[-1]):
idx = matching_inds_plus_base[:, :, i]
idx = idx.contiguous().view(-1)
right_features_select = right_features_view[:, idx]
right_features_select = right_features_select.view(channel, batch_size, -1).transpose(0, 1)
match_cost_i = torch.sum(left_features * right_features_select, dim=1, keepdim=True) / self.temperature
match_cost.append(match_cost_i)
match_cost = torch.cat(match_cost, dim=1).transpose(1, 2)
match_cost = F.softmax(match_cost, dim=-1)
right_input_view = right_input.transpose(0, 1).contiguous().view(right_input.size()[1], -1)
warp = torch.zeros_like(right_input)
# using A[:, idx]
for i in range(match_cost.size()[-1]):
idx = matching_inds_plus_base[:, :, i]
idx = idx.contiguous().view(-1)
right_input_select = right_input_view[:, idx]
right_input_select = right_input_select.view(right_input.size()[1], batch_size, -1).transpose(0, 1)
warp = warp + right_input_select * match_cost[:, :, i].unsqueeze(dim=1)
return matching_inds, warp
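
A small sketch of the propagation step of the PatchMatch loop above: each pixel's candidate offsets are concatenated with its horizontal (or vertical) neighbours' candidates, tripling the candidate dimension. The module path models/networks/patch_match.py is an assumption, and a CUDA device is assumed because these modules call Tensor.get_device():

import torch
from models.networks.patch_match import PropagationFaster   # assumed path

prop = PropagationFaster()
offset_x = torch.zeros(1, 3, 8, 8).cuda()   # 3 candidate offsets per pixel
offset_y = torch.zeros(1, 3, 8, 8).cuda()
px, py = prop(offset_x, offset_y, "horizontal")
# (1, 9, 8, 8) each: every pixel now also carries its neighbours' candidates
print(px.shape, py.shape)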

255
models/pix2pix_model.py Normal file

@@ -0,0 +1,255 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn.functional as F
import models.networks as networks
import util.util as util
import itertools
try:
from torch.cuda.amp import autocast
except ImportError:
# dummy autocast for PyTorch < 1.6
class autocast:
def __init__(self, enabled):
pass
def __enter__(self):
pass
def __exit__(self, *args):
pass
class Pix2PixModel(torch.nn.Module):
@staticmethod
def modify_commandline_options(parser, is_train):
networks.modify_commandline_options(parser, is_train)
return parser
def __init__(self, opt):
super().__init__()
self.opt = opt
self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \
else torch.FloatTensor
self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \
else torch.ByteTensor
self.net = torch.nn.ModuleDict(self.initialize_networks(opt))
# set loss functions
if opt.isTrain:
# vgg network
self.vggnet_fix = networks.architecture.VGG19_feature_color_torchversion(vgg_normal_correct=opt.vgg_normal_correct)
self.vggnet_fix.load_state_dict(torch.load('vgg/vgg19_conv.pth'))
self.vggnet_fix.eval()
for param in self.vggnet_fix.parameters():
param.requires_grad = False
self.vggnet_fix.to(self.opt.gpu_ids[0])
# contextual loss
self.contextual_forward_loss = networks.ContextualLoss_forward(opt)
# GAN loss
self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
# L1 loss
self.criterionFeat = torch.nn.L1Loss()
# L2 loss
self.MSE_loss = torch.nn.MSELoss()
# setting which layer is used in the perceptual loss
if opt.which_perceptual == '5_2':
self.perceptual_layer = -1
elif opt.which_perceptual == '4_2':
self.perceptual_layer = -2
def forward(self, data, mode, GforD=None):
input_label, input_semantics, real_image, self_ref, ref_image, ref_label, ref_semantics = self.preprocess_input(data, )
generated_out = {}
if mode == 'generator':
g_loss, generated_out = self.compute_generator_loss(input_label, \
input_semantics, real_image, ref_label, \
ref_semantics, ref_image, self_ref)
out = {}
out['fake_image'] = generated_out['fake_image']
out['input_semantics'] = input_semantics
out['ref_semantics'] = ref_semantics
out['warp_out'] = None if 'warp_out' not in generated_out else generated_out['warp_out']
out['adaptive_feature_seg'] = None if 'adaptive_feature_seg' not in generated_out else generated_out['adaptive_feature_seg']
out['adaptive_feature_img'] = None if 'adaptive_feature_img' not in generated_out else generated_out['adaptive_feature_img']
out['warp_cycle'] = None if 'warp_cycle' not in generated_out else generated_out['warp_cycle']
return g_loss, out
elif mode == 'discriminator':
d_loss = self.compute_discriminator_loss(input_semantics, \
real_image, GforD, label=input_label)
return d_loss
elif mode == 'inference':
out = {}
with torch.no_grad():
out = self.inference(input_semantics, ref_semantics=ref_semantics, \
ref_image=ref_image, self_ref=self_ref, \
real_image=real_image)
out['input_semantics'] = input_semantics
out['ref_semantics'] = ref_semantics
return out
else:
raise ValueError("|mode| is invalid")
def create_optimizers(self, opt):
if opt.no_TTUR:
beta1, beta2 = opt.beta1, opt.beta2
G_lr, D_lr = opt.lr, opt.lr
else:
beta1, beta2 = 0, 0.9
G_lr, D_lr = opt.lr / 2, opt.lr * 2
optimizer_G = torch.optim.Adam(itertools.chain(self.net['netG'].parameters(), \
self.net['netCorr'].parameters()), lr=G_lr, betas=(beta1, beta2), eps=1e-3)
optimizer_D = torch.optim.Adam(itertools.chain(self.net['netD'].parameters()), \
lr=D_lr, betas=(beta1, beta2))
return optimizer_G, optimizer_D
def save(self, epoch):
util.save_network(self.net['netG'], 'G', epoch, self.opt)
util.save_network(self.net['netD'], 'D', epoch, self.opt)
util.save_network(self.net['netCorr'], 'Corr', epoch, self.opt)
def initialize_networks(self, opt):
net = {}
net['netG'] = networks.define_G(opt)
net['netD'] = networks.define_D(opt) if opt.isTrain else None
net['netCorr'] = networks.define_Corr(opt)
if not opt.isTrain or opt.continue_train:
net['netCorr'] = util.load_network(net['netCorr'], 'Corr', opt.which_epoch, opt)
net['netG'] = util.load_network(net['netG'], 'G', opt.which_epoch, opt)
if opt.isTrain:
net['netD'] = util.load_network(net['netD'], 'D', opt.which_epoch, opt)
return net
def preprocess_input(self, data):
if self.use_gpu():
for k in data.keys():
try:
data[k] = data[k].cuda()
except:
continue
label = data['label'][:,:3,:,:].float()
label_ref = data['label_ref'][:,:3,:,:].float()
input_semantics = data['label'].float()
ref_semantics = data['label_ref'].float()
image = data['image']
ref = data['ref']
self_ref = data['self_ref']
return label, input_semantics, image, self_ref, ref, label_ref, ref_semantics
def get_ctx_loss(self, source, target):
contextual_style5_1 = torch.mean(self.contextual_forward_loss(source[-1], target[-1].detach())) * 8
contextual_style4_1 = torch.mean(self.contextual_forward_loss(source[-2], target[-2].detach())) * 4
contextual_style3_1 = torch.mean(self.contextual_forward_loss(F.avg_pool2d(source[-3], 2), F.avg_pool2d(target[-3].detach(), 2))) * 2
return contextual_style5_1 + contextual_style4_1 + contextual_style3_1
def compute_generator_loss(self, input_label, input_semantics, real_image, ref_label=None, ref_semantics=None, ref_image=None, self_ref=None):
G_losses = {}
generate_out = self.generate_fake(input_semantics, real_image, ref_semantics=ref_semantics, ref_image=ref_image, self_ref=self_ref)
generate_out['fake_image'] = generate_out['fake_image'].float()
weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
sample_weights = self_ref/(sum(self_ref)+1e-5)
sample_weights = sample_weights.view(-1, 1, 1, 1)
"""domain align"""
if 'loss_novgg_featpair' in generate_out and generate_out['loss_novgg_featpair'] is not None:
G_losses['no_vgg_feat'] = generate_out['loss_novgg_featpair']
"""warping cycle"""
if self.opt.weight_warp_cycle > 0:
warp_cycle = generate_out['warp_cycle']
scale_factor = ref_image.size()[-1] // warp_cycle.size()[-1]
ref = F.avg_pool2d(ref_image, scale_factor, stride=scale_factor)
G_losses['G_warp_cycle'] = F.l1_loss(warp_cycle, ref) * self.opt.weight_warp_cycle
"""warping loss"""
if self.opt.weight_warp_self > 0:
"""512x512"""
warp1, warp2, warp3, warp4 = generate_out['warp_out']
G_losses['G_warp_self'] = \
torch.mean(F.l1_loss(warp4, real_image, reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0 + \
torch.mean(F.l1_loss(warp3, F.avg_pool2d(real_image, 2, stride=2), reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0 + \
torch.mean(F.l1_loss(warp2, F.avg_pool2d(real_image, 4, stride=4), reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0 + \
torch.mean(F.l1_loss(warp1, F.avg_pool2d(real_image, 8, stride=8), reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0
"""gan loss"""
pred_fake, pred_real = self.discriminate(input_semantics, generate_out['fake_image'], real_image)
G_losses['GAN'] = self.criterionGAN(pred_fake, True, for_discriminator=False) * self.opt.weight_gan
if not self.opt.no_ganFeat_loss:
num_D = len(pred_fake)
GAN_Feat_loss = 0.0
for i in range(num_D):
# for each discriminator
# last output is the final prediction, so we exclude it
num_intermediate_outputs = len(pred_fake[i]) - 1
for j in range(num_intermediate_outputs):
# for each layer output
unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
GAN_Feat_loss += unweighted_loss * self.opt.weight_ganFeat / num_D
G_losses['GAN_Feat'] = GAN_Feat_loss
"""feature matching loss"""
fake_features = self.vggnet_fix(generate_out['fake_image'], ['r12', 'r22', 'r32', 'r42', 'r52'], preprocess=True)
loss = 0
for i in range(len(generate_out['real_features'])):
loss += weights[i] * util.weighted_l1_loss(fake_features[i], generate_out['real_features'][i].detach(), sample_weights)
G_losses['fm'] = loss * self.opt.weight_vgg * self.opt.weight_fm_ratio
"""perceptual loss"""
feat_loss = util.mse_loss(fake_features[self.perceptual_layer], generate_out['real_features'][self.perceptual_layer].detach())
G_losses['perc'] = feat_loss * self.opt.weight_perceptual
"""contextual loss"""
G_losses['contextual'] = self.get_ctx_loss(fake_features, generate_out['ref_features']) * self.opt.weight_vgg * self.opt.weight_contextual
return G_losses, generate_out
def compute_discriminator_loss(self, input_semantics, real_image, GforD, label=None):
D_losses = {}
with torch.no_grad():
fake_image = GforD['fake_image'].detach()
fake_image.requires_grad_()
pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)
D_losses['D_Fake'] = self.criterionGAN(pred_fake, False, for_discriminator=True) * self.opt.weight_gan
D_losses['D_real'] = self.criterionGAN(pred_real, True, for_discriminator=True) * self.opt.weight_gan
return D_losses
def encode_z(self, real_image):
mu, logvar = self.net['netE'](real_image)
z = self.reparameterize(mu, logvar)
return z, mu, logvar
def generate_fake(self, input_semantics, real_image, ref_semantics=None, ref_image=None, self_ref=None):
generate_out = {}
generate_out['ref_features'] = self.vggnet_fix(ref_image, ['r12', 'r22', 'r32', 'r42', 'r52'], preprocess=True)
generate_out['real_features'] = self.vggnet_fix(real_image, ['r12', 'r22', 'r32', 'r42', 'r52'], preprocess=True)
with autocast(enabled=self.opt.amp):
corr_out = self.net['netCorr'](ref_image, real_image, input_semantics, ref_semantics)
generate_out['fake_image'] = self.net['netG'](input_semantics, warp_out=corr_out['warp_out'])
generate_out = {**generate_out, **corr_out}
return generate_out
def inference(self, input_semantics, ref_semantics=None, ref_image=None, self_ref=None, real_image=None):
generate_out = {}
with autocast(enabled=self.opt.amp):
corr_out = self.net['netCorr'](ref_image, real_image, input_semantics, ref_semantics)
generate_out['fake_image'] = self.net['netG'](input_semantics, warp_out=corr_out['warp_out'])
generate_out = {**generate_out, **corr_out}
return generate_out
def discriminate(self, input_semantics, fake_image, real_image):
fake_concat = torch.cat([input_semantics, fake_image], dim=1)
real_concat = torch.cat([input_semantics, real_image], dim=1)
fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
with autocast(enabled=self.opt.amp):
discriminator_out = self.net['netD'](fake_and_real)
pred_fake, pred_real = self.divide_pred(discriminator_out)
return pred_fake, pred_real
def divide_pred(self, pred):
if type(pred) == list:
fake = []
real = []
for p in pred:
fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
real.append([tensor[tensor.size(0) // 2:] for tensor in p])
else:
fake = pred[:pred.size(0) // 2]
real = pred[pred.size(0) // 2:]
return fake, real
def use_gpu(self):
return len(self.opt.gpu_ids) > 0
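
For orientation, a sketch of the batch dictionary that preprocess_input above consumes. The key names come from the code; the batch size, spatial resolution, channel counts and the self_ref shape are illustrative assumptions, since the real values come from the dataset loader:

import torch

batch = {
    'label':     torch.zeros(4, 3, 512, 512),   # target semantic / pose map
    'label_ref': torch.zeros(4, 3, 512, 512),   # reference semantic / pose map
    'image':     torch.zeros(4, 3, 512, 512),   # ground-truth target image
    'ref':       torch.zeros(4, 3, 512, 512),   # exemplar (reference) image
    'self_ref':  torch.ones(4),                 # per-sample self-reference indicator (shape assumed)
}
# g_loss, out = model(batch, mode='generator')   # model = Pix2PixModel(opt)
# out_inf     = model(batch, mode='inference')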

2
options/__init__.py Normal file

@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

166
options/base_options.py Normal file

@@ -0,0 +1,166 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import sys
import random
import argparse
import pickle
import numpy as np
import torch
import models
import data
from util import util
class BaseOptions():
def __init__(self):
self.initialized = False
def initialize(self, parser):
# experiment specifics
parser.add_argument('--name', type=str, default='deepfashionHD', help='name of the experiment. It decides where to store samples and models')
parser.add_argument('--gpu_ids', type=str, default='0,1,2,3', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
parser.add_argument('--model', type=str, default='pix2pix', help='which model to use')
parser.add_argument('--norm_G', type=str, default='spectralinstance', help='instance normalization or batch normalization')
parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization')
parser.add_argument('--norm_E', type=str, default='spectralinstance', help='instance normalization or batch normalization')
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
# input/output sizes
parser.add_argument('--batchSize', type=int, default=4, help='input batch size')
parser.add_argument('--preprocess_mode', type=str, default='scale_width_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none"))
parser.add_argument('--load_size', type=int, default=256, help='Scale images to this size. The final image will be cropped to --crop_size.')
parser.add_argument('--crop_size', type=int, default=256, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. The final height of the load image will be crop_size/aspect_ratio')
parser.add_argument('--label_nc', type=int, default=182, help='# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dontcare_label.')
parser.add_argument('--contain_dontcare_label', action='store_true', help='if the label map contains dontcare label (dontcare=255)')
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
# for setting inputs
parser.add_argument('--dataroot', type=str, default='dataset/deepfashionHD')
parser.add_argument('--dataset_mode', type=str, default='deepfashionHD')
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
parser.add_argument('--nThreads', default=16, type=int, help='# threads for loading data')
parser.add_argument('--max_dataset_size', type=int, default=sys.maxsize, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default')
parser.add_argument('--cache_filelist_write', action='store_true', help='saves the current filelist into a text file, so that it loads faster')
parser.add_argument('--cache_filelist_read', action='store_true', help='reads from the file list cache')
# for displays
parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
# for generator
parser.add_argument('--netG', type=str, default='spade', help='selects model to use for netG (pix2pixhd | spade)')
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
# for feature encoder
parser.add_argument('--netCorr', type=str, default='NoVGGHPM')
parser.add_argument('--nef', type=int, default=32, help='# of gen filters in first conv layer')
# for instance-wise features
parser.add_argument('--CBN_intype', type=str, default='warp_mask', help='type of CBN input for framework, warp/mask/warp_mask')
parser.add_argument('--match_kernel', type=int, default=1, help='correspondence matrix match kernel size')
parser.add_argument('--featEnc_kernel', type=int, default=3, help='kernel size in domain adaptor')
parser.add_argument('--PONO', action='store_true', help='use positional normalization ')
parser.add_argument('--PONO_C', action='store_true', help='use C normalization in corr module')
parser.add_argument('--vgg_normal_correct', action='store_true', help='if true, correct vgg normalization and replace vgg FM model with ctx model')
parser.add_argument('--use_coordconv', action='store_true', help='if true, use coordconv in CorrNet')
parser.add_argument('--video_like', action='store_true', help='useful in deepfashion')
parser.add_argument('--amp', action='store_true', help='use torch.cuda.amp')
parser.add_argument('--temperature', type=float, default=0.01)
parser.add_argument('--iteration_count', type=int, default=5)
self.initialized = True
return parser
def gather_options(self):
# initialize parser with basic options
if not self.initialized:
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = self.initialize(parser)
# get the basic options
opt, unknown = parser.parse_known_args()
# modify model-related parser options
model_name = opt.model
model_option_setter = models.get_option_setter(model_name)
parser = model_option_setter(parser, self.isTrain)
# modify dataset-related parser options
dataset_mode = opt.dataset_mode
dataset_option_setter = data.get_option_setter(dataset_mode)
parser = dataset_option_setter(parser, self.isTrain)
opt, unknown = parser.parse_known_args()
# if there is opt_file, load it.
# The previous default options will be overwritten
if opt.load_from_opt_file:
parser = self.update_options_from_file(parser, opt)
opt = parser.parse_args()
self.parser = parser
return opt
def print_options(self, opt):
message = ''
message += '----------------- Options ---------------\n'
for k, v in sorted(vars(opt).items()):
comment = ''
default = self.parser.get_default(k)
if v != default:
comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
message += '----------------- End -------------------'
print(message)
def option_file_path(self, opt, makedir=False):
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
if makedir:
util.mkdirs(expr_dir)
file_name = os.path.join(expr_dir, 'opt')
return file_name
def save_options(self, opt):
file_name = self.option_file_path(opt, makedir=True)
with open(file_name + '.txt', 'wt') as opt_file:
for k, v in sorted(vars(opt).items()):
comment = ''
default = self.parser.get_default(k)
if v != default:
comment = '\t[default: %s]' % str(default)
opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
with open(file_name + '.pkl', 'wb') as opt_file:
pickle.dump(opt, opt_file)
def update_options_from_file(self, parser, opt):
new_opt = self.load_options(opt)
for k, v in sorted(vars(opt).items()):
if hasattr(new_opt, k) and v != getattr(new_opt, k):
new_val = getattr(new_opt, k)
parser.set_defaults(**{k: new_val})
return parser
def load_options(self, opt):
file_name = self.option_file_path(opt, makedir=False)
new_opt = pickle.load(open(file_name + '.pkl', 'rb'))
return new_opt
def parse(self, save=False):
# gather options from base, train, dataset, model
opt = self.gather_options()
# train or test
opt.isTrain = self.isTrain
self.print_options(opt)
if opt.isTrain:
self.save_options(opt)
# Set semantic_nc based on the option.
# This will be convenient in many places
opt.semantic_nc = opt.label_nc + (1 if opt.contain_dontcare_label else 0)
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
str_ids = opt.gpu_ids.split(',')
opt.gpu_ids = list(range(len(str_ids)))
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = True
if len(opt.gpu_ids) > 0:
torch.cuda.set_device(opt.gpu_ids[0])
self.opt = opt
return self.opt

20
options/test_options.py Normal file

@@ -0,0 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .base_options import BaseOptions
class TestOptions(BaseOptions):
def initialize(self, parser):
BaseOptions.initialize(self, parser)
parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
parser.add_argument('--how_many', type=int, default=float("inf"), help='how many test images to run')
parser.add_argument('--save_per_img', action='store_true', help='if specified, save per image')
parser.add_argument('--show_corr', action='store_true', help='if specified, save bilinear upsample correspondence')
parser.set_defaults(preprocess_mode='scale_width_and_crop', crop_size=256, load_size=256, display_winsize=256)
parser.set_defaults(serial_batches=True)
parser.set_defaults(no_flip=True)
parser.set_defaults(phase='test')
self.isTrain = False
return parser

46
options/train_options.py Normal file

@@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
def initialize(self, parser):
BaseOptions.initialize(self, parser)
# for displays
parser.add_argument('--display_freq', type=int, default=2000, help='frequency of showing training results on screen')
parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
# for training
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate. This is NOT the total #epochs. Total #epochs is niter + niter_decay')
parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero')
parser.add_argument('--optimizer', type=str, default='adam')
parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iteration.')
# for discriminators
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale|image)')
parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use the TTUR training scheme')
parser.add_argument('--real_reference_probability', type=float, default=0.0, help='self-supervised training probability')
parser.add_argument('--hard_reference_probability', type=float, default=0.0, help='hard reference training probability')
# training loss weights
parser.add_argument('--weight_warp_self', type=float, default=0.0, help='push warp self to ref')
parser.add_argument('--weight_warp_cycle', type=float, default=0.0, help='push warp cycle to ref')
parser.add_argument('--weight_novgg_featpair', type=float, default=10.0, help='in no vgg setting, use pair feat loss in domain adaptation')
parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)')
parser.add_argument('--weight_gan', type=float, default=10.0, help='weight of the GAN loss')
parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
parser.add_argument('--weight_ganFeat', type=float, default=10.0, help='weight for feature matching loss')
parser.add_argument('--which_perceptual', type=str, default='4_2', help='relu5_2 or relu4_2')
parser.add_argument('--weight_perceptual', type=float, default=0.001)
parser.add_argument('--weight_vgg', type=float, default=10.0, help='weight for vgg loss')
parser.add_argument('--weight_contextual', type=float, default=1.0, help='ctx loss weight')
parser.add_argument('--weight_fm_ratio', type=float, default=1.0, help='vgg fm loss weight comp with ctx loss')
self.isTrain = True
return parser

10
requirements.txt Normal file

@@ -0,0 +1,10 @@
torch==1.7.0
torchvision
matplotlib
pillow
imageio
numpy
pandas
scipy
scikit-image
opencv-python

Binary data
slides/cocosnet_v2_slides.pdf Normal file

Binary file not shown.

46
test.py Normal file

@@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
from torchvision.utils import save_image
import os
import imageio
import numpy as np
import data
from util.util import mkdir
from options.test_options import TestOptions
from models.pix2pix_model import Pix2PixModel
if __name__ == '__main__':
opt = TestOptions().parse()
dataloader = data.create_dataloader(opt)
model = Pix2PixModel(opt)
if len(opt.gpu_ids) > 1:
model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
else:
model.to(opt.gpu_ids[0])
model.eval()
save_root = os.path.join(opt.checkpoints_dir, opt.name, 'test')
mkdir(save_root)
for i, data_i in enumerate(dataloader):
print('{} / {}'.format(i, len(dataloader)))
if i * opt.batchSize >= opt.how_many:
break
imgs_num = data_i['label'].shape[0]
out = model(data_i, mode='inference')
if opt.save_per_img:
try:
for it in range(imgs_num):
save_name = os.path.join(save_root, '%08d_%04d.png' % (i, it))
save_image(out['fake_image'][it:it+1], save_name, padding=0, normalize=True)
except OSError as err:
print(err)
else:
label = data_i['label'][:,:3,:,:]
imgs = torch.cat((label.cpu(), data_i['ref'].cpu(), out['fake_image'].data.cpu()), 0)
try:
save_name = os.path.join(save_root, '%08d.png' % i)
save_image(imgs, save_name, nrow=imgs_num, padding=0, normalize=True)
except OSError as err:
print(err)

90
train.py Normal file

@@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import sys
import torch
from torchvision.utils import save_image
from options.train_options import TrainOptions
import data
from util.iter_counter import IterationCounter
from util.util import print_current_errors
from util.util import mkdir
from trainers.pix2pix_trainer import Pix2PixTrainer
if __name__ == '__main__':
# parse options
opt = TrainOptions().parse()
# print options to help debugging
print(' '.join(sys.argv))
dataloader = data.create_dataloader(opt)
len_dataloader = len(dataloader)
# create tool for counting iterations
iter_counter = IterationCounter(opt, len(dataloader))
# create trainer for our model
trainer = Pix2PixTrainer(opt, resume_epoch=iter_counter.first_epoch)
save_root = os.path.join('checkpoints', opt.name, 'train')
mkdir(save_root)
for epoch in iter_counter.training_epochs():
opt.epoch = epoch
iter_counter.record_epoch_start(epoch)
for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
iter_counter.record_one_iteration()
# Training
# train generator
if i % opt.D_steps_per_G == 0:
trainer.run_generator_one_step(data_i)
# train discriminator
trainer.run_discriminator_one_step(data_i)
if iter_counter.needs_printing():
losses = trainer.get_latest_losses()
try:
print_current_errors(opt, epoch, iter_counter.epoch_iter,
iter_counter.epoch_iter_num, losses, iter_counter.time_per_iter)
except OSError as err:
print(err)
if iter_counter.needs_displaying():
imgs_num = data_i['label'].shape[0]
if opt.dataset_mode == 'deepfashionHD':
label = data_i['label'][:,:3,:,:]
show_size = opt.display_winsize
imgs = torch.cat((label.cpu(), data_i['ref'].cpu(), \
trainer.get_latest_generated().data.cpu(), \
data_i['image'].cpu()), 0)
try:
save_name = '%08d_%08d.png' % (epoch, iter_counter.total_steps_so_far)
save_name = os.path.join(save_root, save_name)
save_image(imgs, save_name, nrow=imgs_num, padding=0, normalize=True)
except OSError as err:
print(err)
if iter_counter.needs_saving():
print('saving the latest model (epoch %d, total_steps %d)' %
(epoch, iter_counter.total_steps_so_far))
try:
trainer.save('latest')
iter_counter.record_current_iter()
except OSError as err:
print(err)
trainer.update_learning_rate(epoch)
iter_counter.record_epoch_end()
if epoch % opt.save_epoch_freq == 0 or epoch == iter_counter.total_epochs:
print('saving the model at the end of epoch %d, iters %d' %
(epoch, iter_counter.total_steps_so_far))
try:
trainer.save('latest')
trainer.save(epoch)
except OSError as err:
print(err)
print('Training was successfully finished.')

2
trainers/__init__.py Normal file

@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

124
trainers/pix2pix_trainer.py Normal file

@@ -0,0 +1,124 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import copy
import sys
import torch
from models.pix2pix_model import Pix2PixModel
try:
from torch.cuda.amp import GradScaler
except ImportError:
# dummy GradScaler for PyTorch < 1.6
class GradScaler:
def __init__(self, enabled):
pass
def scale(self, loss):
return loss
def unscale_(self, optimizer):
pass
def step(self, optimizer):
optimizer.step()
def update(self):
pass
class Pix2PixTrainer():
"""
Trainer creates the model and optimizers, and uses them to
update the weights of the network while reporting losses
and the latest visuals to visualize the progress in training.
"""
def __init__(self, opt, resume_epoch=0):
self.opt = opt
self.pix2pix_model = Pix2PixModel(opt)
if len(opt.gpu_ids) > 1:
self.pix2pix_model = torch.nn.DataParallel(self.pix2pix_model, device_ids=opt.gpu_ids)
self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
else:
self.pix2pix_model.to(opt.gpu_ids[0])
self.pix2pix_model_on_one_gpu = self.pix2pix_model
self.generated = None
if opt.isTrain:
self.optimizer_G, self.optimizer_D = self.pix2pix_model_on_one_gpu.create_optimizers(opt)
self.old_lr = opt.lr
if opt.continue_train and opt.which_epoch == 'latest':
try:
load_path = os.path.join(opt.checkpoints_dir, opt.name, 'optimizer.pth')
checkpoint = torch.load(load_path)
self.optimizer_G.load_state_dict(checkpoint['G'])
self.optimizer_D.load_state_dict(checkpoint['D'])
except FileNotFoundError as err:
print(err)
print('Could not find optimizer state dict: ' + load_path + '. Optimizer not loaded!')
self.last_data, self.last_netCorr, self.last_netG, self.last_optimizer_G = None, None, None, None
self.g_losses = {}
self.d_losses = {}
self.scaler = GradScaler(enabled=self.opt.amp)
def run_generator_one_step(self, data):
self.optimizer_G.zero_grad()
g_losses, out = self.pix2pix_model(data, mode='generator')
g_loss = sum(g_losses.values()).mean()
# g_loss.backward()
self.scaler.scale(g_loss).backward()
self.scaler.unscale_(self.optimizer_G)
# self.optimizer_G.step()
self.scaler.step(self.optimizer_G)
self.scaler.update()
self.g_losses = g_losses
self.out = out
def run_discriminator_one_step(self, data):
self.optimizer_D.zero_grad()
GforD = {}
GforD['fake_image'] = self.out['fake_image']
GforD['adaptive_feature_seg'] = self.out['adaptive_feature_seg']
GforD['adaptive_feature_img'] = self.out['adaptive_feature_img']
d_losses = self.pix2pix_model(data, mode='discriminator', GforD=GforD)
d_loss = sum(d_losses.values()).mean()
# d_loss.backward()
self.scaler.scale(d_loss).backward()
self.scaler.unscale_(self.optimizer_D)
# self.optimizer_D.step()
self.scaler.step(self.optimizer_D)
self.scaler.update()
self.d_losses = d_losses
def get_latest_losses(self):
return {**self.g_losses, **self.d_losses}
def get_latest_generated(self):
return self.out['fake_image']
def save(self, epoch):
self.pix2pix_model_on_one_gpu.save(epoch)
if epoch == 'latest':
torch.save({'G': self.optimizer_G.state_dict(), \
'D': self.optimizer_D.state_dict(), \
'lr': self.old_lr,}, \
os.path.join(self.opt.checkpoints_dir, self.opt.name, 'optimizer.pth'))
def update_learning_rate(self, epoch):
if epoch > self.opt.niter:
lrd = self.opt.lr / self.opt.niter_decay
new_lr = self.old_lr - lrd
else:
new_lr = self.old_lr
if new_lr != self.old_lr:
new_lr_G = new_lr
new_lr_D = new_lr
else:
new_lr_G = self.old_lr
new_lr_D = self.old_lr
for param_group in self.optimizer_D.param_groups:
param_group['lr'] = new_lr_D
for param_group in self.optimizer_G.param_groups:
param_group['lr'] = new_lr_G
print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
self.old_lr = new_lr
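
The decay schedule above keeps the learning rate constant for the first --niter epochs and then subtracts lr / niter_decay once per epoch, reaching zero after niter + niter_decay epochs. A small sketch with hypothetical values:

lr, niter, niter_decay = 2e-4, 100, 20   # hypothetical settings
old_lr, history = lr, []
for epoch in range(1, niter + niter_decay + 1):
    if epoch > niter:                    # mirrors update_learning_rate()
        old_lr -= lr / niter_decay
    history.append(old_lr)
print(history[99], history[100], history[-1])   # 2e-4, 1.9e-4, ~0.0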

2
util/__init__.py Normal file

@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

73
util/iter_counter.py Normal file

@@ -0,0 +1,73 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import time
import numpy as np
# Helper class that keeps track of training iterations
class IterationCounter():
def __init__(self, opt, dataset_size):
self.opt = opt
self.dataset_size = dataset_size
self.batch_size = opt.batchSize
self.first_epoch = 1
self.total_epochs = opt.niter + opt.niter_decay
# iter number within each epoch
self.epoch_iter = 0
self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt')
if opt.isTrain and opt.continue_train:
try:
self.first_epoch, self.epoch_iter = np.loadtxt(self.iter_record_path, delimiter=',', dtype=int)
print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter))
except:
print('Could not load iteration record at %s. Starting from beginning.' % self.iter_record_path)
self.epoch_iter_num = dataset_size * self.batch_size
self.total_steps_so_far = (self.first_epoch - 1) * self.epoch_iter_num + self.epoch_iter
self.continue_train_flag = opt.continue_train
# return the iterator of epochs for the training
def training_epochs(self):
return range(self.first_epoch, self.total_epochs + 1)
def record_epoch_start(self, epoch):
self.epoch_start_time = time.time()
if not self.continue_train_flag:
self.epoch_iter = 0
else:
self.continue_train_flag = False
self.last_iter_time = time.time()
self.current_epoch = epoch
def record_one_iteration(self):
current_time = time.time()
# the last remaining batch is dropped (see data/__init__.py),
# so we can assume batch size is always opt.batchSize
self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize
self.last_iter_time = current_time
self.total_steps_so_far += self.opt.batchSize
self.epoch_iter += self.opt.batchSize
def record_epoch_end(self):
current_time = time.time()
self.time_per_epoch = current_time - self.epoch_start_time
print('End of epoch %d / %d \t Time Taken: %d sec' %
(self.current_epoch, self.total_epochs, self.time_per_epoch))
if self.current_epoch % self.opt.save_epoch_freq == 0:
np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0),
delimiter=',', fmt='%d')
print('Saved current iteration count at %s.' % self.iter_record_path)
def record_current_iter(self):
np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter),
delimiter=',', fmt='%d')
print('Saved current iteration count at %s.' % self.iter_record_path)
def needs_saving(self):
return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize
def needs_printing(self):
return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize
def needs_displaying(self):
return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize

106
util/util.py Normal file

@@ -0,0 +1,106 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import re
import argparse
from argparse import Namespace
import torch
import numpy as np
import importlib
from PIL import Image
def feature_normalize(feature_in, eps=1e-10):
feature_in_norm = torch.norm(feature_in, 2, 1, keepdim=True) + eps
feature_in_norm = torch.div(feature_in, feature_in_norm)
return feature_in_norm
def weighted_l1_loss(input, target, weights):
out = torch.abs(input - target)
out = out * weights.expand_as(out)
loss = out.mean()
return loss
def mse_loss(input, target=0):
return torch.mean((input - target)**2)
def vgg_preprocess(tensor, vgg_normal_correct=False):
if vgg_normal_correct:
tensor = (tensor + 1) / 2
# input is RGB tensor which ranges in [0,1]
# output is BGR tensor which ranges in [0,255]
tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1)
# tensor_bgr = tensor[:, [2, 1, 0], ...]
tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view(1, 3, 1, 1)
tensor_rst = tensor_bgr_ml * 255
return tensor_rst
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
def find_class_in_module(target_cls_name, module):
target_cls_name = target_cls_name.replace('_', '').lower()
clslib = importlib.import_module(module)
cls = None
for name, clsobj in clslib.__dict__.items():
if name.lower() == target_cls_name:
cls = clsobj
if cls is None:
print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
exit(0)
return cls
def save_network(net, label, epoch, opt):
save_filename = '%s_net_%s.pth' % (epoch, label)
save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
torch.save(net.cpu().state_dict(), save_path)
if len(opt.gpu_ids) and torch.cuda.is_available():
net.cuda()
def load_network(net, label, epoch, opt):
save_filename = '%s_net_%s.pth' % (epoch, label)
save_dir = os.path.join(opt.checkpoints_dir, opt.name)
save_path = os.path.join(save_dir, save_filename)
if not os.path.exists(save_path):
print('Could not find model: ' + save_path + ', model not loaded!')
return net
weights = torch.load(save_path)
try:
net.load_state_dict(weights)
except KeyError:
print('key error, not load!')
except RuntimeError as err:
print(err)
net.load_state_dict(weights, strict=False)
print('loaded with strict = False')
print('Load from ' + save_path)
return net
def print_current_errors(opt, epoch, i, num, errors, t):
message = '(epoch: %d, iters: %d, finish: %.2f%%, time: %.3f) ' % (epoch, i, (i/num)*100.0, t)
for k, v in errors.items():
v = v.mean().float()
message += '%s: %.3f ' % (k, v)
print(message)
log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
with open(log_name, "a") as log_file:
log_file.write('%s\n' % message)
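
A small sketch of how the dynamic class lookup above resolves a sub-architecture name into a class, mirroring the call in MultiscaleDiscriminator.modify_commandline_options (assuming the discriminator classes live in models/networks/discriminator.py):

import util.util as util

# 'n_layer' + 'discriminator' is lowercased and stripped of underscores before matching.
cls = util.find_class_in_module('n_layer' + 'discriminator', 'models.networks.discriminator')
print(cls)   # expected: <class 'models.networks.discriminator.NLayerDiscriminator'>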