зеркало из https://github.com/microsoft/ProDA.git
196 строки
7.7 KiB
Python
196 строки
7.7 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import os
|
|
import sys
|
|
import time
|
|
import torch
|
|
import random
|
|
import argparse
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from parser_train import parser_, relative_path_to_absolute_path
|
|
|
|
from tqdm import tqdm
|
|
from data import create_dataset
|
|
from utils import get_logger
|
|
from models import adaptation_modelv2
|
|
from metrics import runningScore, averageMeter
|
|
|
|
def train(opt, logger):
    """Run the adaptation training loop for the stage selected by ``opt.stage``.

    For every target batch one source batch is pulled, a single optimisation
    step is executed (``step_adv`` for ``'warm_up'``, ``step`` for
    ``'stage1'``, ``step_distillation`` for any other stage), progress is
    logged every ``opt.print_interval`` iterations and ``validation`` is run
    every ``opt.val_interval`` iterations.

    Args:
        opt: parsed options namespace. Fields read here: ``seed``,
            ``model_name``, ``n_class``, ``stage``, ``resume_path``,
            ``tgt_dataset``, ``epochs``, ``freeze_bn``, ``print_interval``,
            ``val_interval``, ``train_iters`` and ``bs``.
        logger: logger handed to the dataset/model factories and used for
            progress messages.

    Raises:
        ValueError: if ``opt.model_name`` is not ``'deeplabv2'``, the only
            architecture this function knows how to build.
    """
    # Fail fast: previously an unsupported name left `model` unbound and only
    # surfaced later as a NameError, after the datasets had been created.
    if opt.model_name != 'deeplabv2':
        raise ValueError(
            "Unsupported model_name {!r}; only 'deeplabv2' is available".format(opt.model_name)
        )

    # Seed every RNG the pipeline touches.
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)
    np.random.seed(opt.seed)
    random.seed(opt.seed)

    ## create dataset
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    datasets = create_dataset(opt, logger)

    model = adaptation_modelv2.CustomModel(opt, logger)

    # Setup Metrics
    running_metrics_val = runningScore(opt.n_class)
    time_meter = averageMeter()

    # load category anchors
    if opt.stage == 'stage1':
        prototype_file = os.path.join(
            os.path.dirname(opt.resume_path),
            'prototypes_on_{}_from_{}'.format(opt.tgt_dataset, opt.model_name),
        )
        objective_vectors = torch.load(prototype_file)
        # Use the selected device rather than a hard-coded GPU index.
        model.objective_vectors = torch.Tensor(objective_vectors).to(device)

    # begin training
    model.iter = 0
    for epoch in range(opt.epochs):
        for data_i in datasets.target_train_loader:
            target_image = data_i['img'].to(device)
            target_imageS = data_i['img_strong'].to(device)
            target_params = data_i['params']
            target_image_full = data_i['img_full'].to(device)
            target_weak_params = data_i['weak_params']

            # Pseudo labels are optional in the batch dict.
            target_lp = data_i['lp'].to(device) if 'lp' in data_i.keys() else None
            target_lpsoft = data_i['lpsoft'].to(device) if 'lpsoft' in data_i.keys() else None

            # The source loader is advanced manually through its own next().
            source_data = datasets.source_train_loader.next()

            model.iter += 1
            i = model.iter
            images = source_data['img'].to(device)
            labels = source_data['label'].to(device)
            source_imageS = source_data['img_strong'].to(device)
            source_params = source_data['params']

            start_ts = time.time()

            model.train(logger=logger)
            if opt.freeze_bn:
                model.freeze_bn_apply()
            model.optimizer_zerograd()

            # `losses` keeps (label, value) pairs in the order they are logged.
            if opt.stage == 'warm_up':
                loss_GTA, loss_G, loss_D = model.step_adv(
                    images, labels, target_image, source_imageS, source_params)
                losses = (('loss_GTA', loss_GTA), ('loss_G', loss_G), ('loss_D', loss_D))
            elif opt.stage == 'stage1':
                loss, loss_CTS, loss_consist = model.step(
                    images, labels, target_image, target_imageS, target_params, target_lp,
                    target_lpsoft, target_image_full, target_weak_params)
                losses = (('loss', loss), ('loss_CTS', loss_CTS), ('loss_consist', loss_consist))
            else:
                loss_GTA, loss = model.step_distillation(
                    images, labels, target_image, target_imageS, target_params, target_lp)
                losses = (('loss_GTA', loss_GTA), ('loss', loss))

            time_meter.update(time.time() - start_ts)

            if (i + 1) % opt.print_interval == 0:
                loss_str = ' '.join('{}: {:.4f}'.format(name, value) for name, value in losses)
                print_str = "Epochs [{:d}/{:d}] Iter [{:d}/{:d}] {} Time/Image: {:.4f}".format(
                    epoch + 1, opt.epochs, i + 1, opt.train_iters, loss_str,
                    time_meter.avg / opt.bs)
                print(print_str)
                logger.info(print_str)
                time_meter.reset()

            # evaluation
            if (i + 1) % opt.val_interval == 0:
                validation(model, logger, datasets, device, running_metrics_val,
                           iters=model.iter, opt=opt)
                torch.cuda.empty_cache()
                logger.info('Best iou until now is {}'.format(model.best_iou))

            model.scheduler_step()
|
|
|
|
def validation(model, logger, datasets, device, running_metrics_val, iters, opt=None):
    """Evaluate on the target validation loader and write checkpoints.

    Logs the learning rate of each optimizer, switches the model to eval
    mode, accumulates metrics via ``validate``, logs the resulting scores,
    then saves a "current" checkpoint and, when the mean IoU is at least
    ``model.best_iou``, a "best" checkpoint (updating ``model.best_iou``).

    Args:
        model: object exposing ``optimizers``, ``nets``, ``eval``,
            ``objective_vectors`` and ``best_iou``.
        logger: logger receiving learning-rate and score messages.
        datasets: object exposing ``target_valid_loader``.
        device: device passed through to ``validate``.
        running_metrics_val: metric accumulator providing ``get_scores`` and
            ``reset``.
        iters: current iteration count; stored in the checkpoint as
            ``iters + 1``.
        opt: options namespace; ``logdir``, ``src_dataset``, ``tgt_dataset``
            and ``model_name`` are used to build the checkpoint paths.

    Returns:
        The value stored under the ``"Mean IoU : \\t"`` key of the scores.

    Raises:
        ValueError: if ``opt`` is omitted; it is required to locate the
            checkpoint directory.
    """
    # Checked up front so a missing `opt` does not surface only after a full
    # evaluation pass.
    if opt is None:
        raise ValueError("validation() requires `opt` to build checkpoint paths")

    # One message per optimizer, reporting the lr of its last param group.
    for idx, optimizer in enumerate(model.optimizers):
        for param_group in optimizer.param_groups:
            _learning_rate = param_group.get('lr')
        logger.info("learning rate is {} for {} net".format(
            _learning_rate, model.nets[idx].__class__.__name__))

    model.eval(logger=logger)
    torch.cuda.empty_cache()
    val_datset = datasets.target_valid_loader
    with torch.no_grad():
        validate(val_datset, device, model, running_metrics_val)

    score, class_iou = running_metrics_val.get_scores()
    for k, v in score.items():
        print(k, v)
        logger.info('{}: {}'.format(k, v))

    for k, v in class_iou.items():
        logger.info('{}: {}'.format(k, v))

    running_metrics_val.reset()
    mean_iou = score["Mean IoU : \t"]

    torch.cuda.empty_cache()
    # One entry per network keyed by class name; optimizer and scheduler
    # states are intentionally not stored.
    state = {}
    for net in model.nets:
        state[net.__class__.__name__] = {
            "model_state": net.state_dict(),
            "objective_vectors": model.objective_vectors,
        }
    state['iter'] = iters + 1
    state['best_iou'] = mean_iou

    run_tag = "from_{}_to_{}_on_{}".format(opt.src_dataset, opt.tgt_dataset, opt.model_name)
    torch.save(state, os.path.join(opt.logdir, run_tag + "_current_model.pkl"))

    if mean_iou >= model.best_iou:
        torch.cuda.empty_cache()
        model.best_iou = mean_iou
        # `best_iou` in the dict already equals the new best, so the same
        # state is written instead of being rebuilt.
        torch.save(state, os.path.join(opt.logdir, run_tag + "_best_model.pkl"))
    return mean_iou
|
|
|
|
def validate(valid_loader, device, model, running_metrics_val):
    """Feed every validation batch through the model and update the metrics.

    Args:
        valid_loader: iterable of batch dicts with ``'img'`` and ``'label'``
            tensors.
        device: device the images are moved to before inference.
        model: object whose ``BaseNet_DP(images)`` returns a dict with the
            logits under ``'out'``.
        running_metrics_val: accumulator updated with ``update(gt, pred)``
            using NumPy arrays.
    """
    for batch in tqdm(valid_loader):
        images_val = batch['img'].to(device)

        out = model.BaseNet_DP(images_val)

        # Resize the logits to the spatial size of the input images.
        outputs = F.interpolate(out['out'], size=images_val.size()[2:],
                                mode='bilinear', align_corners=True)

        # Index of the highest-scoring channel along dim 1.
        pred = outputs.detach().argmax(dim=1).cpu().numpy()
        # Labels are only consumed as NumPy arrays, so they are not moved to
        # `device` first.
        gt = batch['label'].detach().cpu().numpy()
        running_metrics_val.update(gt, pred)
|
|
|
|
if __name__ == "__main__":
    # Build the option namespace from the project parser and resolve paths.
    parser = argparse.ArgumentParser(description="config")
    parser = parser_(parser)
    opt = parser.parse_args()

    opt = relative_path_to_absolute_path(opt)

    print('RUNDIR: {}'.format(opt.logdir))
    # exist_ok avoids the check-then-create race when several runs start at
    # the same time with the same log directory.
    os.makedirs(opt.logdir, exist_ok=True)

    logger = get_logger(opt.logdir)

    train(opt, logger)
|