This commit is contained in:
apsyx 2021-02-06 02:08:14 -08:00
Parent 050736efe5
Commit 1a68fd743c
8 changed files with 1073 additions and 0 deletions

153
CIFAR/data_helper_cifar.py Normal file

@@ -0,0 +1,153 @@
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

from utils import DataIterator


def prepare_data(gold_fraction, corruption_prob, corruption_type, args):
    if args.use_mwnet_loader:
        return prepare_data_mwnet(gold_fraction, corruption_prob, corruption_type, args)
    else:
        return prepare_data_mlc(gold_fraction, corruption_prob, corruption_type, args)
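
# Terminology used below (added note): the "gold" loaders serve the small trusted subset with
# clean labels (meta=True), while the "silver" loaders serve the remaining training data whose
# labels are corrupted with probability corruption_prob.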

def prepare_data_mwnet(gold_fraction, corruption_prob, corruption_type, args):
    from load_corrupted_data_mlg import CIFAR10, CIFAR100

    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    if True:  # use the same augmentation pipeline as MW-Net (the disabled else-branch skips augmentation)
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                              (4, 4, 4, 4), mode='reflect').squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    args.num_meta = int(50000 * gold_fraction)

    if args.dataset == 'cifar10':
        num_classes = 10

        train_data_meta = CIFAR10(
            root=args.data_path, train=True, meta=True, num_meta=args.num_meta, corruption_prob=corruption_prob,
            corruption_type=args.corruption_type, transform=train_transform, download=True)
        train_data = CIFAR10(
            root=args.data_path, train=True, meta=False, num_meta=args.num_meta, corruption_prob=corruption_prob,
            corruption_type=args.corruption_type, transform=train_transform, download=True, seed=args.seed)
        test_data = CIFAR10(root=args.data_path, train=False, transform=test_transform, download=True)
        valid_data = CIFAR10(
            root=args.data_path, train=True, meta=True, num_meta=args.num_meta, corruption_prob=corruption_prob,
            corruption_type=args.corruption_type, transform=train_transform, download=True)
    elif args.dataset == 'cifar100':
        num_classes = 100

        train_data_meta = CIFAR100(
            root=args.data_path, train=True, meta=True, num_meta=args.num_meta, corruption_prob=corruption_prob,
            corruption_type=args.corruption_type, transform=train_transform, download=True)
        train_data = CIFAR100(
            root=args.data_path, train=True, meta=False, num_meta=args.num_meta, corruption_prob=corruption_prob,
            corruption_type=args.corruption_type, transform=train_transform, download=True, seed=args.seed)
        test_data = CIFAR100(root=args.data_path, train=False, transform=test_transform, download=True)
        valid_data = CIFAR100(
            root=args.data_path, train=True, meta=True, num_meta=args.num_meta, corruption_prob=corruption_prob,
            corruption_type=args.corruption_type, transform=train_transform, download=True)

    train_gold_loader = DataIterator(torch.utils.data.DataLoader(train_data_meta, batch_size=args.bs, shuffle=True,
                                                                 num_workers=args.prefetch, pin_memory=True))
    train_silver_loader = torch.utils.data.DataLoader(train_data, batch_size=args.bs, shuffle=True,
                                                      num_workers=args.prefetch, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=args.bs, shuffle=True,
                                               num_workers=args.prefetch, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.bs, shuffle=False,
                                              num_workers=args.prefetch, pin_memory=True)

    return train_gold_loader, train_silver_loader, valid_loader, test_loader, num_classes

def prepare_data_mlc(gold_fraction, corruption_prob, corruption_type, args):
    from load_corrupted_data import CIFAR10, CIFAR100

    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]

    train_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    test_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)])

    # since cifar10 and cifar100 have no official validation split, use gold as valid also
    if args.dataset == 'cifar10':
        train_data_gold = CIFAR10(
            args.data_path, True, True, gold_fraction, corruption_prob, args.corruption_type,
            transform=train_transform, download=True, distinguish_gold=False, seed=args.seed)
        train_data_silver = CIFAR10(
            args.data_path, True, False, gold_fraction, corruption_prob, args.corruption_type,
            transform=train_transform, download=True, shuffle_indices=train_data_gold.shuffle_indices,
            seed=args.seed, distinguish_gold=False, weaklabel=args.weaklabel)  # note the weaklabel arg
        train_data_gold_deterministic = CIFAR10(
            args.data_path, True, True, gold_fraction, corruption_prob, args.corruption_type,
            transform=test_transform, download=True, shuffle_indices=train_data_gold.shuffle_indices,
            distinguish_gold=False, seed=args.seed)
        test_data = CIFAR10(args.data_path, train=False, transform=test_transform, download=True,
                            distinguish_gold=False, seed=args.seed)

        # same as gold
        valid_data = CIFAR10(
            args.data_path, True, True, gold_fraction, corruption_prob, args.corruption_type,
            transform=train_transform, download=True, distinguish_gold=False, seed=args.seed)

        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data_gold = CIFAR100(
            args.data_path, True, True, gold_fraction, corruption_prob, args.corruption_type,
            transform=train_transform, download=True, distinguish_gold=False, seed=args.seed)
        train_data_silver = CIFAR100(
            args.data_path, True, False, gold_fraction, corruption_prob, args.corruption_type,
            transform=train_transform, download=True, shuffle_indices=train_data_gold.shuffle_indices,
            seed=args.seed, distinguish_gold=False, weaklabel=args.weaklabel)  # note the weaklabel arg
        train_data_gold_deterministic = CIFAR100(
            args.data_path, True, True, gold_fraction, corruption_prob, args.corruption_type,
            transform=test_transform, download=True, shuffle_indices=train_data_gold.shuffle_indices,
            distinguish_gold=False, seed=args.seed)
        test_data = CIFAR100(args.data_path, train=False, transform=test_transform, download=True,
                             distinguish_gold=False, seed=args.seed)

        # same as gold
        valid_data = CIFAR100(
            args.data_path, True, True, gold_fraction, corruption_prob, args.corruption_type,
            transform=train_transform, download=True, distinguish_gold=False, seed=args.seed)

        num_classes = 100

    gold_sampler = None
    silver_sampler = None
    valid_sampler = None
    test_sampler = None
    batch_size = args.bs

    train_gold_loader = DataIterator(torch.utils.data.DataLoader(
        train_data_gold, batch_size=batch_size, shuffle=(gold_sampler is None),
        num_workers=args.prefetch, pin_memory=True, sampler=gold_sampler))
    train_silver_loader = torch.utils.data.DataLoader(
        train_data_silver, batch_size=batch_size, shuffle=(silver_sampler is None),
        num_workers=args.prefetch, pin_memory=True, sampler=silver_sampler)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=(valid_sampler is None),
                                               num_workers=args.prefetch, pin_memory=True, sampler=valid_sampler)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=(test_sampler is None),
                                              num_workers=args.prefetch, pin_memory=True, sampler=test_sampler)

    return train_gold_loader, train_silver_loader, valid_loader, test_loader, num_classes

162
CIFAR/resnet.py Normal file

@@ -0,0 +1,162 @@
'''
Credit to https://github.com/akamaster
Properly implemented ResNet-s for CIFAR10 as described in paper [1].
The implementation and structure of this file is hugely influenced by [2],
which is implemented for ImageNet and doesn't have option A for identity.
Moreover, most of the implementations on the web are copy-pasted from
torchvision's resnet and have the wrong number of params.
Proper ResNet-s for CIFAR10 (for fair comparison etc.) have the following
numbers of layers and parameters:
name      | layers | params
ResNet20  |   20   | 0.27M
ResNet32  |   32   | 0.46M
ResNet44  |   44   | 0.66M
ResNet56  |   56   | 0.85M
ResNet110 |  110   | 1.7M
ResNet1202|  1202  | 19.4M
which this implementation indeed has.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
If you use this implementation in your work, please don't forget to mention the
author, Yerlan Idelbayev.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']


def _weights_init(m):
    classname = m.__class__.__name__
    # print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
"""
For CIFAR10 ResNet paper uses option A.
"""
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x, return_h=False):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        hidden = out.view(out.size(0), -1)
        out = self.linear(hidden)

        if return_h:
            return out, hidden
        else:
            return out


def resnet20():
    return ResNet(BasicBlock, [3, 3, 3])


def resnet32(num_classes=10):
    return ResNet(BasicBlock, [5, 5, 5], num_classes)


def resnet44():
    return ResNet(BasicBlock, [7, 7, 7])


def resnet56():
    return ResNet(BasicBlock, [9, 9, 9])


def resnet110():
    return ResNet(BasicBlock, [18, 18, 18])


def resnet1202():
    return ResNet(BasicBlock, [200, 200, 200])


def test(net):
    import numpy as np
    total_params = 0

    for x in filter(lambda p: p.requires_grad, net.parameters()):
        total_params += np.prod(x.data.numpy().shape)
    print("Total number of params", total_params)
    print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size()) > 1, net.parameters()))))


if __name__ == "__main__":
    for net_name in __all__:
        if net_name.startswith('resnet'):
            print(net_name)
            test(globals()[net_name]())
            print()

26
logger.py Normal file

@@ -0,0 +1,26 @@
import logging


def get_logger(filename, local_rank):
    formatter = logging.Formatter(fmt='[%(asctime)s %(levelname)s] %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')

    logger = logging.getLogger()
    logger.handlers = []
    logger.setLevel(logging.INFO)
    logger.propagate = False

    if filename is not None and local_rank <= 0:  # only log to file for first GPU
        f_handler = logging.FileHandler(filename, 'a')
        f_handler.setLevel(logging.INFO)
        f_handler.setFormatter(formatter)
        logger.addHandler(f_handler)

        stdout_handler = logging.StreamHandler()
        stdout_handler.setFormatter(formatter)
        stdout_handler.setLevel(logging.INFO)
        logger.addHandler(stdout_handler)
    else:  # null handlers for other GPUs
        null_handler = logging.NullHandler()
        null_handler.setLevel(logging.INFO)
        logger.addHandler(null_handler)

    return logger
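
# Illustrative usage (added example, not in the original file):
#   logger = get_logger('logs/run.log', local_rank=-1)  # local/single-GPU run: file + stdout handlers
#   logger = get_logger('logs/run.log', local_rank=1)   # non-zero rank: only a NullHandler is attached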

425
main.py Normal file

@@ -0,0 +1,425 @@
import numpy as np
import pickle
import copy
import sys
import argparse
import logging
import os
from logger import get_logger
from tqdm import tqdm
from collections import deque

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter

from mlc import step_hmlc_K
from mlc_utils import clone_parameters, tocuda, DummyScheduler

from models import *
from meta_models import *

parser = argparse.ArgumentParser(description='MLC Training Framework')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'clothing1m'], default='cifar10')
parser.add_argument('--method', default='hmlc_K_mix', type=str, choices=['hmlc_K_mix', 'hmlc_K'])
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--data_seed', type=int, default=1)
parser.add_argument('--epochs', '-e', type=int, default=75, help='Number of epochs to train.')
parser.add_argument('--num_iterations', default=100000, type=int)
parser.add_argument('--every', default=100, type=int, help='Eval interval (default: 100 iters)')
parser.add_argument('--bs', default=32, type=int, help='batch size')
parser.add_argument('--test_bs', default=100, type=int, help='test batch size')
parser.add_argument('--gold_bs', type=int, default=32)
parser.add_argument('--cls_dim', type=int, default=64, help='Label embedding dim (Default: 64)')
parser.add_argument('--grad_clip', default=0.0, type=float, help='max grad norm (default: 0, no clip)')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum for optimizer')
parser.add_argument('--main_lr', default=3e-4, type=float, help='lr for main net')
parser.add_argument('--meta_lr', default=3e-5, type=float, help='lr for meta net')
parser.add_argument('--optimizer', default='adam', type=str, choices=['adam', 'sgd', 'adadelta'])
parser.add_argument('--opt_eps', default=1e-8, type=float, help='eps for optimizers')
#parser.add_argument('--tau', default=1, type=float, help='tau')
parser.add_argument('--wdecay', default=5e-4, type=float, help='weight decay (default: 5e-4)')

# noise parameters
parser.add_argument('--corruption_type', default='unif', type=str, choices=['unif', 'flip'])
parser.add_argument('--corruption_level', default=-1, type=float, help='Corruption level')
parser.add_argument('--gold_fraction', default=-1, type=float, help='Gold fraction')

parser.add_argument('--sparsemax', default=False, action='store_true', help='Use sparsemax instead of softmax for the meta model (default: False)')
parser.add_argument('--tie', default=False, action='store_true', help='Tie the label embedding to the output classifier weights of the meta net (default: False)')

parser.add_argument('--runid', default='exp', type=str)
parser.add_argument('--queue_size', default=1, type=int, help='Number of most recent validation metrics to average when selecting the best model (default: 1)')

############## LOOK-AHEAD GRADIENT STEPS FOR MLC ##################
parser.add_argument('--gradient_steps', default=1, type=int, help='Number of look-ahead gradient steps for meta-gradient (default: 1)')

# CIFAR
# Positional arguments
parser.add_argument('--data_path', default='data', type=str, help='Root for the datasets.')
# Optimization options
parser.add_argument('--nosgdr', default=False, action='store_true', help='Turn off SGDR.')
# Acceleration
parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.')
# i/o
parser.add_argument('--logdir', type=str, default='runs', help='Log folder.')
parser.add_argument('--local_rank', type=int, default=-1, help='local rank (-1 for local training)')

args = parser.parse_args()

# //////////////////////////// set logging ///////////////////////
filename = '_'.join([args.dataset, args.method, args.corruption_type, args.runid, str(args.epochs), str(args.seed), str(args.data_seed)])
logfile = 'logs/' + filename + '.log'
logger = get_logger(logfile, args.local_rank)
# //////////////////////////////////////////////////////////////////

logger.info(args)
logger.info('CUDA available:' + str(torch.cuda.is_available()))

# cuda set up
torch.cuda.set_device(0)  # local GPU
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# loss function for hard label and soft labels
hard_loss_f = F.cross_entropy
from mlc_utils import soft_cross_entropy as soft_loss_f


# //////////////////////// defining model ////////////////////////
def get_data(dataset, gold_fraction, corruption_prob, get_C):
    if dataset == 'cifar10' or dataset == 'cifar100':
        sys.path.append('CIFAR')
        from data_helper_cifar import prepare_data
        args.use_mwnet_loader = True  # use exactly the same loader as in the mwnet paper
        logger.info('================= Use the same dataloader as in MW-Net =========================')
        return prepare_data(gold_fraction, corruption_prob, get_C, args)
    elif dataset == 'clothing1m':
        sys.path.append('CLOTHING1M')
        from data_helper_clothing1m import prepare_data
        return prepare_data(args)


def build_models(dataset, num_classes):
    cls_dim = args.cls_dim  # input label embedding dimension

    if dataset in ['cifar10', 'cifar100']:
        from CIFAR.resnet import resnet32

        # main net
        model = resnet32(num_classes)
        main_net = model

        # meta net
        hx_dim = 64  # penultimate feature dimension of resnet-32
        meta_net = MetaNet(hx_dim, cls_dim, 128, num_classes, args)
    elif dataset == 'clothing1m':  # use pretrained ResNet-50 model
        model = ResNet50(num_classes)
        main_net = model

        hx_dim = 2048  # from resnet50
        meta_net = MetaNet(2048, cls_dim, 128, num_classes, args)

    main_net = main_net.cuda()
    meta_net = meta_net.cuda()

    logger.info('========== Main model ==========')
    logger.info(model)
    logger.info('========== Meta model ==========')
    logger.info(meta_net)

    return main_net, meta_net


def setup_training(main_net, meta_net, exp_id=None):
    # ============== setting up from scratch ===================
    # set up optimizers and schedulers

    # meta net optimizer
    optimizer = torch.optim.Adam(meta_net.parameters(), lr=args.meta_lr,
                                 weight_decay=0,  # args.wdecay, # meta should have wdecay or not??
                                 amsgrad=True, eps=args.opt_eps)
    scheduler = DummyScheduler(optimizer)

    # main net optimizer
    main_params = main_net.parameters()

    if args.optimizer == 'adam':
        main_opt = torch.optim.Adam(main_params, lr=args.main_lr, weight_decay=args.wdecay, amsgrad=True, eps=args.opt_eps)
    elif args.optimizer == 'sgd':
        main_opt = torch.optim.SGD(main_params, lr=args.main_lr, weight_decay=args.wdecay, momentum=args.momentum)

    if args.dataset in ['cifar10', 'cifar100']:
        # follow MW-Net setting
        main_schdlr = torch.optim.lr_scheduler.MultiStepLR(main_opt, milestones=[80, 100], gamma=0.1)
    elif args.dataset in ['clothing1m']:
        main_schdlr = torch.optim.lr_scheduler.MultiStepLR(main_opt, milestones=[5], gamma=0.1)
    else:
        main_schdlr = DummyScheduler(main_opt)

    last_epoch = -1

    return main_net, meta_net, main_opt, optimizer, main_schdlr, scheduler, last_epoch


def uniform_mix_C(num_classes, mixing_ratio):
    '''
    returns a linear interpolation of a uniform matrix and an identity matrix
    '''
    return mixing_ratio * np.full((num_classes, num_classes), 1 / num_classes) + \
        (1 - mixing_ratio) * np.eye(num_classes)
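
# Illustrative example (added comment): uniform_mix_C(3, 0.3) gives
#   [[0.8, 0.1, 0.1],
#    [0.1, 0.8, 0.1],
#    [0.1, 0.1, 0.8]]
# i.e. each label is kept with probability 0.8 and otherwise flipped to a uniformly random class.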


def flip_labels_C(num_classes, corruption_prob):
    '''
    returns a matrix with (1 - corruption_prob) on the diagonal, and corruption_prob
    concentrated in only one other entry for each row
    '''
    np.random.seed(args.seed)

    C = np.eye(num_classes) * (1 - corruption_prob)
    row_indices = np.arange(num_classes)
    for i in range(num_classes):
        C[i][np.random.choice(row_indices[row_indices != i])] = corruption_prob

    return C
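
# Illustrative example (added comment; the off-diagonal positions depend on the seed):
# flip_labels_C(3, 0.4) might give
#   [[0.6, 0.4, 0.0],
#    [0.0, 0.6, 0.4],
#    [0.4, 0.0, 0.6]]
# i.e. each class keeps its label with probability 0.6 and is otherwise flipped to one fixed other class.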


# //////////////////////// run experiments ////////////////////////
def run():
    corruption_fnctn = uniform_mix_C if args.corruption_type == 'unif' else flip_labels_C
    filename = '_'.join([args.dataset, args.method, args.corruption_type, args.runid, str(args.epochs), str(args.seed), str(args.data_seed)])

    results = {}

    # 100 labels per class
    gf_dict = {'yelp2': 200.0 / 560000,
               'yelp5': 500.0 / 650000,
               'amazon2': 200.0 / 3600000,
               'amazon5': 500.0 / 3000000,
               'dbpedia': 1400.0 / 560000,
               'yahoo': 700.0 / 1400000,
               'imdb2': 200.0 / 25000,
               'ag': 400.0 / 120000,
               }

    # revisit this
    gold_fractions = [0.05, 0.001, 0.01]

    if args.gold_fraction != -1:
        assert args.gold_fraction >= 0 and args.gold_fraction <= 1, 'Wrong gold fraction!'
        gold_fractions = [args.gold_fraction]

    corruption_levels = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

    if args.corruption_level != -1:  # a single corruption_level was specified
        assert args.corruption_level >= 0 and args.corruption_level <= 1, 'Wrong noise level!'
        corruption_levels = [args.corruption_level]

    for gold_fraction in gold_fractions:
        results[gold_fraction] = {}
        for corruption_level in corruption_levels:
            # //////////////////////// load data //////////////////////////////
            # use data_seed here
            gold_loader, silver_loader, valid_loader, test_loader, num_classes = get_data(args.dataset, gold_fraction, corruption_level, corruption_fnctn)

            # //////////////////////// build main_net and meta_net/////////////
            main_net, meta_net = build_models(args.dataset, num_classes)

            # //////////////////////// train and eval model ///////////////////
            exp_id = '_'.join([filename, str(gold_fraction), str(corruption_level)])
            test_acc, baseline_acc = train_and_test(main_net, meta_net, gold_loader, silver_loader, valid_loader, test_loader, exp_id)

            results[gold_fraction][corruption_level] = {}
            results[gold_fraction][corruption_level]['method'] = test_acc
            results[gold_fraction][corruption_level]['baseline'] = baseline_acc
            logger.info(' '.join(['Gold fraction:', str(gold_fraction), '| Corruption level:', str(corruption_level),
                                  '| Method acc:', str(results[gold_fraction][corruption_level]['method']),
                                  '| Baseline acc:', str(results[gold_fraction][corruption_level]['baseline'])]))
            logger.info('')

    with open('out/' + filename, 'wb') as file:
        pickle.dump(results, file)
    logger.info("Dumped results_ours in file: " + filename)


def test(main_net, test_loader):  # this could be eval or test
    # //////////////////////// evaluate method ////////////////////////
    correct = torch.zeros(1).cuda()
    nsamples = torch.zeros(1).cuda()

    # forward
    main_net.eval()

    for idx, (*data, target) in enumerate(test_loader):
        data, target = tocuda(data), tocuda(target)

        # forward
        with torch.no_grad():
            output = main_net(data)

        # accuracy
        pred = output.data.max(1)[1]
        correct += pred.eq(target.data).sum().item()
        nsamples += len(target)

    test_acc = (correct / nsamples).item()

    # set back to train
    main_net.train()

    return test_acc


####################################################################################################
### training code
####################################################################################################
def train_and_test(main_net, meta_net, gold_loader, silver_loader, valid_loader, test_loader, exp_id=None):
    writer = SummaryWriter(args.logdir + '/' + exp_id)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    main_net, meta_net, main_opt, optimizer, main_schdlr, scheduler, last_epoch = setup_training(main_net, meta_net, exp_id)

    # //////////////////////// switching on training mode ////////////////////////
    meta_net.train()
    main_net.train()

    # set up statistics
    best_params = None
    best_main_opt_sd = None
    best_main_schdlr_sd = None

    best_meta_params = None
    best_meta_opt_sd = None
    best_meta_schdlr_sd = None

    best_val_metric = float('inf')
    val_metric_queue = deque()

    # set done
    args.dw_prev = [0 for param in meta_net.parameters()]  # 0 for previous iteration
    args.steps = 0

    for epoch in tqdm(range(last_epoch + 1, args.epochs)):  # changed to epoch iteration
        logger.info('Epoch %d:' % epoch)

        for i, (*data_s, target_s) in enumerate(silver_loader):
            *data_g, target_g = next(gold_loader)

            data_g, target_g = tocuda(data_g), tocuda(target_g)
            data_s, target_s_ = tocuda(data_s), tocuda(target_s)

            # bi-level optimization stage
            eta = main_schdlr.get_lr()[0]
            if args.method == 'hmlc_K':
                loss_g, loss_s = step_hmlc_K(main_net, main_opt, hard_loss_f,
                                             meta_net, optimizer, soft_loss_f,
                                             data_s, target_s_, data_g, target_g,
                                             None, None,
                                             eta, args)
            elif args.method == 'hmlc_K_mix':
                # split the clean set in two: one half for training and the other for meta-evaluation
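                # For example (added note): with a gold batch of 32, the first 16 examples stay in
                # data_g/target_g and drive the meta objective (loss_g), while the last 16 become
                # data_c/target_c and are mixed into the training loss inside step_hmlc_K.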
                gbs = int(target_g.size(0) / 2)
                if type(data_g) is list:
                    data_c = [x[gbs:] for x in data_g]
                    data_g = [x[:gbs] for x in data_g]
                else:
                    data_c = data_g[gbs:]
                    data_g = data_g[:gbs]

                target_c = target_g[gbs:]
                target_g = target_g[:gbs]
                loss_g, loss_s = step_hmlc_K(main_net, main_opt, hard_loss_f,
                                             meta_net, optimizer, soft_loss_f,
                                             data_s, target_s_, data_g, target_g,
                                             data_c, target_c,
                                             eta, args)

            args.steps += 1
            if i % args.every == 0:
                writer.add_scalar('train/loss_g', loss_g.item(), args.steps)
                writer.add_scalar('train/loss_s', loss_s.item(), args.steps)

                ''' get entropy of predictions from meta-net '''
                logit_s, x_s_h = main_net(data_s, return_h=True)
                pseudo_target_s = meta_net(x_s_h.detach(), target_s_).detach()
                entropy = -(pseudo_target_s * torch.log(pseudo_target_s + 1e-10)).sum(-1).mean()

                writer.add_scalar('train/meta_entropy', entropy.item(), args.steps)

                main_lr = main_schdlr.get_lr()[0]
                meta_lr = scheduler.get_lr()[0]
                writer.add_scalar('train/main_lr', main_lr, args.steps)
                writer.add_scalar('train/meta_lr', meta_lr, args.steps)
                writer.add_scalar('train/gradient_steps', args.gradient_steps, args.steps)

                logger.info('Iteration %d loss_s: %.4f\tloss_g: %.4f\tMeta entropy: %.3f\tMain LR: %.8f\tMeta LR: %.8f' % (
                    i, loss_s.item(), loss_g.item(), entropy.item(), main_lr, meta_lr))

        # PER EPOCH PROCESSING
        # lr scheduler
        main_schdlr.step()
        # scheduler.step()

        # evaluation on validation set
        val_acc = test(main_net, valid_loader)
        test_acc = test(main_net, test_loader)

        logger.info('Val acc: %.4f\tTest acc: %.4f' % (val_acc, test_acc))

        if args.local_rank <= 0:  # single GPU or GPU 0
            writer.add_scalar('train/val_acc', val_acc, epoch)
            writer.add_scalar('test/test_acc', test_acc, epoch)

            if len(val_metric_queue) == args.queue_size:  # keep at most this number of records
                # remove the oldest record
                val_metric_queue.popleft()

            val_metric_queue.append(-val_acc)

            avg_val_metric = sum(list(val_metric_queue)) / len(val_metric_queue)
            if avg_val_metric < best_val_metric:
                best_val_metric = avg_val_metric

                best_params = copy.deepcopy(main_net.state_dict())
                best_main_opt_sd = copy.deepcopy(main_opt.state_dict())
                best_main_schdlr_sd = copy.deepcopy(main_schdlr.state_dict())

                best_meta_params = copy.deepcopy(meta_net.state_dict())
                best_meta_opt_sd = copy.deepcopy(optimizer.state_dict())
                best_meta_schdlr_sd = copy.deepcopy(scheduler.state_dict())

                # dump best to file also
                ####################### save best models so far ###################
                logger.info('Saving best models...')

                torch.save({
                    'epoch': epoch,
                    'val_metric': best_val_metric,
                    'main_net': best_params,
                    'main_opt': best_main_opt_sd,
                    'main_schdlr': best_main_schdlr_sd,
                    'meta_net': best_meta_params,
                    'meta_opt': best_meta_opt_sd,
                    'meta_schdlr': best_meta_schdlr_sd
                }, 'models/%s_best.pth' % exp_id)

            writer.add_scalar('train/val_acc_best', -best_val_metric, epoch)  # write current best val_acc to tensorboard

    # //////////////////////// evaluating method ////////////////////////
    main_net.load_state_dict(best_params)
    test_acc = test(main_net, test_loader)  # evaluate best params picked from validation

    writer.add_scalar('test/acc', test_acc, args.steps)  # this test_acc should be roughly the best as it's taken from the best iteration
    logger.info('Test acc: %.4f' % test_acc)

    return test_acc, 0


if __name__ == '__main__':
    run()

67
meta_models.py Normal file

@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MetaNet(nn.Module):
    def __init__(self, hx_dim, cls_dim, h_dim, num_classes, args):
        super().__init__()

        self.args = args

        self.num_classes = num_classes
        self.in_class = self.num_classes
        self.hdim = h_dim
        self.cls_emb = nn.Embedding(self.in_class, cls_dim)

        in_dim = hx_dim + cls_dim

        self.net = nn.Sequential(
            nn.Linear(in_dim, self.hdim),
            nn.Tanh(),
            nn.Linear(self.hdim, self.hdim),
            nn.Tanh(),
            nn.Linear(self.hdim, num_classes, bias=(not self.args.tie))
        )

        if self.args.sparsemax:
            from sparsemax import Sparsemax
            self.sparsemax = Sparsemax(-1)

        self.init_weights()

        if self.args.tie:
            print('Tying cls emb to output cls weight')
            self.net[-1].weight = self.cls_emb.weight

    def init_weights(self):
        nn.init.xavier_uniform_(self.cls_emb.weight)
        nn.init.xavier_normal_(self.net[0].weight)
        nn.init.xavier_normal_(self.net[2].weight)
        nn.init.xavier_normal_(self.net[4].weight)

        self.net[0].bias.data.zero_()
        self.net[2].bias.data.zero_()

        if not self.args.tie:
            assert self.in_class == self.num_classes, 'In and out classes conflict!'
            self.net[4].bias.data.zero_()

    def get_alpha(self):
        return torch.zeros(1)

    def forward(self, hx, y):
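        # Shape reference (added comment): hx is (batch, hx_dim) gradient-stopped features from the
        # main net, y is a (batch,) LongTensor of noisy labels; the output is (batch, num_classes)
        # corrected soft labels, i.e. each row is a probability distribution over classes.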
        bs = hx.size(0)

        y_emb = self.cls_emb(y)
        hin = torch.cat([hx, y_emb], dim=-1)

        logit = self.net(hin)

        if self.args.sparsemax:
            out = self.sparsemax(logit)  # test sparsemax
        else:
            out = F.softmax(logit, -1)

        return out

141
mlc.py Normal file

@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def _concat(xs):
    return torch.cat([x.view(-1) for x in xs])


@torch.no_grad()
def update_params(params, grads, eta, opt, args, deltaonly=False, return_s=False):
    if isinstance(opt, torch.optim.SGD):
        return update_params_sgd(params, grads, eta, opt, args, deltaonly, return_s)
    else:
        raise NotImplementedError('Non-supported main model optimizer type!')


# be aware that the opt state dict returns references, hence take care not to
# modify them
def update_params_sgd(params, grads, eta, opt, args, deltaonly, return_s=False):
    # supports SGD-like optimizers
    ans = []

    if return_s:
        ss = []

    wdecay = opt.defaults['weight_decay']
    momentum = opt.defaults['momentum']
    dampening = opt.defaults['dampening']
    nesterov = opt.defaults['nesterov']

    for i, param in enumerate(params):
        dparam = grads[i] + param * wdecay  # s=1
        s = 1

        if momentum > 0:
            try:
                moment = opt.state[param]['momentum_buffer'] * momentum
            except KeyError:  # no momentum buffer yet (first step for this parameter)
                moment = torch.zeros_like(param)

            moment.add_(dparam, alpha=1. - dampening)  # s=1.-dampening

            if nesterov:
                dparam = dparam + momentum * moment  # s=1+momentum*(1.-dampening)
                s = 1 + momentum * (1. - dampening)
            else:
                dparam = moment  # s=1.-dampening
                s = 1. - dampening

        if deltaonly:
            ans.append(- dparam * eta)
        else:
            ans.append(param - dparam * eta)

        if return_s:
            ss.append(s * eta)

    if return_s:
        return ans, ss
    else:
        return ans
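
# Sanity check (added comment): with momentum == 0 and nesterov == False, the loop above reduces to
#   new_param = param - eta * (grad + weight_decay * param)
# and the returned per-parameter scale is simply s * eta = eta.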


# ============== MLC step procedure (debug variant using gradient-stopped features from the main model) ===========
#
# The meta net reuses the last K-1 steps of the main model and imagines one additional step ahead,
# composing a pool of K actual steps from the main model.
#
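# Summary of the first-order meta-gradient computed below (added comment, in the code's own
# notation; Hw, the training-loss Hessian, is approximated by the identity):
#   g_g     = d loss_g / dw         (gold/clean gradient at the current weights w)
#   g_s     = d loss_s / dw         (silver/noisy gradient at w, graph kept through meta_net)
#   w'      ~ w - dparam_s * g_s    (one look-ahead step; dparam_s = s * eta from update_params)
#   g_g'    = d loss_g / dw'        (gold gradient at the look-ahead weights)
#   proxy_g = - < g_s , dparam_s * g_g' >                    (surrogate whose backward pass
#                                                              produces the meta-net gradient)
#   gamma   = < (1 - dparam_s) * g_g' , g_g > / ||g_g||^2    (discount applied to gradients
#                                                              accumulated from earlier steps)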
def step_hmlc_K(main_net, main_opt, hard_loss_f,
                meta_net, meta_opt, soft_loss_f,
                data_s, target_s, data_g, target_g,
                data_c, target_c,
                eta, args):
    # compute gw for updating meta_net
    logit_g = main_net(data_g)
    loss_g = hard_loss_f(logit_g, target_g)
    gw = torch.autograd.grad(loss_g, main_net.parameters())

    # given current meta net, get corrected label
    logit_s, x_s_h = main_net(data_s, return_h=True)
    pseudo_target_s = meta_net(x_s_h.detach(), target_s)
    loss_s = soft_loss_f(logit_s, pseudo_target_s)

    if data_c is not None:
        bs1 = target_s.size(0)
        bs2 = target_c.size(0)

        logit_c = main_net(data_c)
        loss_s2 = hard_loss_f(logit_c, target_c)
        loss_s = (loss_s * bs1 + loss_s2 * bs2) / (bs1 + bs2)

    f_param_grads = torch.autograd.grad(loss_s, main_net.parameters(), create_graph=True)

    f_params_new, dparam_s = update_params(main_net.parameters(), f_param_grads, eta, main_opt, args, return_s=True)

    # 2. set w as w'
    f_param = []
    for i, param in enumerate(main_net.parameters()):
        f_param.append(param.data.clone())
        param.data = f_params_new[i].data  # use data only, as f_params_new carries the graph

    # training loss Hessian approximation
    Hw = 1  # assume to be identity

    # 3. compute d_w' L_{D}(w')
    logit_g = main_net(data_g)
    loss_g = hard_loss_f(logit_g, target_g)
    gw_prime = torch.autograd.grad(loss_g, main_net.parameters())

    # 3.5 compute discount factor gw_prime * (I-LH) * gw.t() / |gw|^2
    tmp1 = [(1 - Hw * dparam_s[i]) * gw_prime[i] for i in range(len(dparam_s))]
    gw_norm2 = (_concat(gw).norm())**2
    tmp2 = [gw[i] / gw_norm2 for i in range(len(gw))]
    gamma = torch.dot(_concat(tmp1), _concat(tmp2))

    # because of dparam_s, need to scale up/down f_params_grads_prime for proxy_g/loss_g
    Lgw_prime = [dparam_s[i] * gw_prime[i] for i in range(len(dparam_s))]

    proxy_g = -torch.dot(_concat(f_param_grads), _concat(Lgw_prime))

    # back prop on alphas
    meta_opt.zero_grad()
    proxy_g.backward()

    # accumulate discounted iterative gradient
    for i, param in enumerate(meta_net.parameters()):
        if param.grad is not None:
            param.grad.add_(gamma * args.dw_prev[i])
            args.dw_prev[i] = param.grad.clone()

    if (args.steps + 1) % (args.gradient_steps) == 0:  # T steps proceeded by main_net
        meta_opt.step()
        args.dw_prev = [0 for param in meta_net.parameters()]  # reset to 0

    # restore w and then do the actual main_net update
    for i, param in enumerate(main_net.parameters()):
        param.data = f_param[i]
        param.grad = f_param_grads[i].data
    main_opt.step()

    return loss_g, loss_s

71
mlc_utils.py Normal file

@@ -0,0 +1,71 @@
import torch
import torch.nn.functional as F


def save_checkpoint(state, filename):
    torch.save(state, filename)


class DummyScheduler(torch.optim.lr_scheduler._LRScheduler):
    def get_lr(self):
        lrs = []
        for param_group in self.optimizer.param_groups:
            lrs.append(param_group['lr'])

        return lrs

    def step(self, epoch=None):
        pass


def tocuda(data):
    if type(data) is list:
        if len(data) == 1:
            return data[0].cuda()
        else:
            return [x.cuda() for x in data]
    else:
        return data.cuda()


'''
def net_grad_norm_max(model, p):
    grad_norms = [x.grad.data.norm(p).item() for x in model.parameters()]
    return max(grad_norms)
'''


def clone_parameters(model):
    assert isinstance(model, torch.nn.Module), 'Wrong model type'

    params = model.named_parameters()

    f_params_dict = {}
    f_params = []
    for idx, (name, param) in enumerate(params):
        new_param = torch.nn.Parameter(param.data.clone())
        f_params_dict[name] = idx
        f_params.append(new_param)

    return f_params, f_params_dict


# target-differentiable version of F.cross_entropy
def soft_cross_entropy(logit, pseudo_target, reduction='mean'):
    loss = -(pseudo_target * F.log_softmax(logit, -1)).sum(-1)
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'none':
        return loss
    elif reduction == 'sum':
        return loss.sum()
    else:
        raise NotImplementedError('Invalid reduction: %s' % reduction)


# test code for soft_cross_entropy
if __name__ == '__main__':
    K = 100

    for _ in range(10000):
        y = torch.randint(K, (100,))
        y_onehot = F.one_hot(y, K).float()

        logits = torch.randn(100, K)

        l1 = F.cross_entropy(logits, y)
        l2 = soft_cross_entropy(logits, y_onehot)

        print(l1.item(), l2.item(), '%.5f' % (l1 - l2).abs().item())

28
models.py Normal file

@@ -0,0 +1,28 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResNet50(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        import torchvision
        import os
        os.environ['TORCH_HOME'] = 'cache'  # hacky workaround to set the model dir
        self.resnet50 = torchvision.models.resnet50(pretrained=True)
        self.resnet50.fc = nn.Identity()  # remove the last fc layer
        self.fc = nn.Linear(2048, num_classes)

        self.init_weights()

    def init_weights(self):
        nn.init.xavier_normal_(self.fc.weight)
        self.fc.bias.data.zero_()

    def forward(self, x, return_h=False):  # x: (bs, C, H, W)
        pooled_output = self.resnet50(x)
        logit = self.fc(pooled_output)

        if return_h:
            return logit, pooled_output
        else:
            return logit