initial commit for CIFAR
Parent: 050736efe5
Commit: 1a68fd743c
@@ -0,0 +1,153 @@
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from utils import DataIterator


def prepare_data(gold_fraction, corruption_prob, corruption_type, args):
    if args.use_mwnet_loader:
        return prepare_data_mwnet(gold_fraction, corruption_prob, corruption_type, args)
    else:
        return prepare_data_mlc(gold_fraction, corruption_prob, corruption_type, args)


def prepare_data_mwnet(gold_fraction, corruption_prob, corruption_type, args):
    from load_corrupted_data_mlg import CIFAR10, CIFAR100
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    if True:  # always use the same augmentation pipeline as MW-Net
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                              (4, 4, 4, 4), mode='reflect').squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:  # no augmentation
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    args.num_meta = int(50000 * gold_fraction)

    if args.dataset == 'cifar10':
        num_classes = 10

        train_data_meta = CIFAR10(
            root=args.data_path, train=True, meta=True, num_meta=args.num_meta, corruption_prob=corruption_prob,
            corruption_type=args.corruption_type, transform=train_transform, download=True)
        train_data = CIFAR10(
            root=args.data_path, train=True, meta=False, num_meta=args.num_meta, corruption_prob=corruption_prob,
            corruption_type=args.corruption_type, transform=train_transform, download=True, seed=args.seed)
        test_data = CIFAR10(root=args.data_path, train=False, transform=test_transform, download=True)

        valid_data = CIFAR10(
            root=args.data_path, train=True, meta=True, num_meta=args.num_meta, corruption_prob=corruption_prob,
            corruption_type=args.corruption_type, transform=train_transform, download=True)

    elif args.dataset == 'cifar100':
        num_classes = 100

        train_data_meta = CIFAR100(
            root=args.data_path, train=True, meta=True, num_meta=args.num_meta, corruption_prob=corruption_prob,
            corruption_type=args.corruption_type, transform=train_transform, download=True)
        train_data = CIFAR100(
            root=args.data_path, train=True, meta=False, num_meta=args.num_meta, corruption_prob=corruption_prob,
            corruption_type=args.corruption_type, transform=train_transform, download=True, seed=args.seed)
        test_data = CIFAR100(root=args.data_path, train=False, transform=test_transform, download=True)

        valid_data = CIFAR100(
            root=args.data_path, train=True, meta=True, num_meta=args.num_meta, corruption_prob=corruption_prob,
            corruption_type=args.corruption_type, transform=train_transform, download=True)

    train_gold_loader = DataIterator(torch.utils.data.DataLoader(train_data_meta, batch_size=args.bs, shuffle=True,
                                                                 num_workers=args.prefetch, pin_memory=True))
    train_silver_loader = torch.utils.data.DataLoader(train_data, batch_size=args.bs, shuffle=True,
                                                      num_workers=args.prefetch, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=args.bs, shuffle=True,
                                               num_workers=args.prefetch, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.bs, shuffle=False,
                                              num_workers=args.prefetch, pin_memory=True)

    return train_gold_loader, train_silver_loader, valid_loader, test_loader, num_classes


def prepare_data_mlc(gold_fraction, corruption_prob, corruption_type, args):
    from load_corrupted_data import CIFAR10, CIFAR100

    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]

    train_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    test_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)])

    # since cifar10 and cifar100 have no official validation split, use gold as valid also
    if args.dataset == 'cifar10':
        train_data_gold = CIFAR10(
            args.data_path, True, True, gold_fraction, corruption_prob, args.corruption_type,
            transform=train_transform, download=True, distinguish_gold=False, seed=args.seed)
        train_data_silver = CIFAR10(
            args.data_path, True, False, gold_fraction, corruption_prob, args.corruption_type,
            transform=train_transform, download=True, shuffle_indices=train_data_gold.shuffle_indices,
            seed=args.seed, distinguish_gold=False, weaklabel=args.weaklabel)  # note: weaklabel is passed only to the silver split
        train_data_gold_deterministic = CIFAR10(
            args.data_path, True, True, gold_fraction, corruption_prob, args.corruption_type,
            transform=test_transform, download=True, shuffle_indices=train_data_gold.shuffle_indices,
            distinguish_gold=False, seed=args.seed)
        test_data = CIFAR10(args.data_path, train=False, transform=test_transform, download=True,
                            distinguish_gold=False, seed=args.seed)

        # same as gold
        valid_data = CIFAR10(
            args.data_path, True, True, gold_fraction, corruption_prob, args.corruption_type,
            transform=train_transform, download=True, distinguish_gold=False, seed=args.seed)

        num_classes = 10

    elif args.dataset == 'cifar100':
        train_data_gold = CIFAR100(
            args.data_path, True, True, gold_fraction, corruption_prob, args.corruption_type,
            transform=train_transform, download=True, distinguish_gold=False, seed=args.seed)
        train_data_silver = CIFAR100(
            args.data_path, True, False, gold_fraction, corruption_prob, args.corruption_type,
            transform=train_transform, download=True, shuffle_indices=train_data_gold.shuffle_indices,
            seed=args.seed, distinguish_gold=False, weaklabel=args.weaklabel)  # note: weaklabel is passed only to the silver split
        train_data_gold_deterministic = CIFAR100(
            args.data_path, True, True, gold_fraction, corruption_prob, args.corruption_type,
            transform=test_transform, download=True, shuffle_indices=train_data_gold.shuffle_indices,
            distinguish_gold=False, seed=args.seed)
        test_data = CIFAR100(args.data_path, train=False, transform=test_transform, download=True,
                             distinguish_gold=False, seed=args.seed)

        # same as gold
        valid_data = CIFAR100(
            args.data_path, True, True, gold_fraction, corruption_prob, args.corruption_type,
            transform=train_transform, download=True, distinguish_gold=False, seed=args.seed)

        num_classes = 100

    gold_sampler = None
    silver_sampler = None
    valid_sampler = None
    test_sampler = None
    batch_size = args.bs

    train_gold_loader = DataIterator(torch.utils.data.DataLoader(
        train_data_gold, batch_size=batch_size, shuffle=(gold_sampler is None),
        num_workers=args.prefetch, pin_memory=True, sampler=gold_sampler))
    train_silver_loader = torch.utils.data.DataLoader(
        train_data_silver, batch_size=batch_size, shuffle=(silver_sampler is None),
        num_workers=args.prefetch, pin_memory=True, sampler=silver_sampler)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=(valid_sampler is None),
                                               num_workers=args.prefetch, pin_memory=True, sampler=valid_sampler)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=(test_sampler is None),
                                              num_workers=args.prefetch, pin_memory=True, sampler=test_sampler)

    return train_gold_loader, train_silver_loader, valid_loader, test_loader, num_classes
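Note that `DataIterator` is imported from `utils`, which is not part of this commit; the training script calls `next()` on the gold loader once per silver batch, so it presumably wraps a DataLoader and cycles it indefinitely. A minimal sketch under that assumption (hypothetical, not the actual utils.py implementation):

# Hypothetical sketch of the DataIterator wrapper assumed above; the real
# implementation lives in utils.py, which is not included in this commit.
class DataIterator:
    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = iter(dataloader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            # restart the underlying DataLoader so next() never runs dry
            self.iterator = iter(self.dataloader)
            return next(self.iterator)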
@@ -0,0 +1,162 @@
'''
Credit to https://github.com/akamaster

Properly implemented ResNet-s for CIFAR10 as described in paper [1].

The implementation and structure of this file are heavily influenced by [2],
which is implemented for ImageNet and doesn't have option A for the identity shortcut.
Moreover, most implementations on the web are copy-pasted from
torchvision's resnet and have the wrong number of params.

Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following
numbers of layers and parameters:

name       | layers | params
ResNet20   |     20 | 0.27M
ResNet32   |     32 | 0.46M
ResNet44   |     44 | 0.66M
ResNet56   |     56 | 0.85M
ResNet110  |    110 | 1.7M
ResNet1202 |   1202 | 19.4M

which this implementation indeed has.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

If you use this implementation in your work, please don't forget to mention the
author, Yerlan Idelbayev.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']


def _weights_init(m):
    classname = m.__class__.__name__
    #print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 the ResNet paper uses option A:
                a parameter-free identity shortcut that subsamples and zero-pads the channels.
                """
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x, return_h=False):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        hidden = out.view(out.size(0), -1)
        out = self.linear(hidden)
        if return_h:
            return out, hidden
        else:
            return out


def resnet20():
    return ResNet(BasicBlock, [3, 3, 3])


def resnet32(num_classes=10):
    return ResNet(BasicBlock, [5, 5, 5], num_classes)


def resnet44():
    return ResNet(BasicBlock, [7, 7, 7])


def resnet56():
    return ResNet(BasicBlock, [9, 9, 9])


def resnet110():
    return ResNet(BasicBlock, [18, 18, 18])


def resnet1202():
    return ResNet(BasicBlock, [200, 200, 200])


def test(net):
    import numpy as np
    total_params = 0

    for x in filter(lambda p: p.requires_grad, net.parameters()):
        total_params += np.prod(x.data.numpy().shape)
    print("Total number of params", total_params)
    print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size()) > 1, net.parameters()))))


if __name__ == "__main__":
    for net_name in __all__:
        if net_name.startswith('resnet'):
            print(net_name)
            test(globals()[net_name]())
            print()
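The training script below builds the 32-layer variant via resnet32(num_classes) and relies on return_h=True exposing the 64-dimensional pre-classifier feature. A quick shape check (assumed usage, not part of the commit):

# Assumed quick check of the resnet32 interface used by the training script.
import torch

net = resnet32(num_classes=10)
x = torch.randn(2, 3, 32, 32)        # a CIFAR-sized batch
logits, hidden = net(x, return_h=True)
print(logits.shape)                  # torch.Size([2, 10])
print(hidden.shape)                  # torch.Size([2, 64]) -> matches hx_dim=64 in the meta net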
@@ -0,0 +1,26 @@
import logging

def get_logger(filename, local_rank):
    formatter = logging.Formatter(fmt='[%(asctime)s %(levelname)s] %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')

    logger = logging.getLogger()
    logger.handlers = []
    logger.setLevel(logging.INFO)
    logger.propagate = False

    if filename is not None and local_rank <= 0:  # only log to file for the first GPU
        f_handler = logging.FileHandler(filename, 'a')
        f_handler.setLevel(logging.INFO)
        f_handler.setFormatter(formatter)
        logger.addHandler(f_handler)

        stdout_handler = logging.StreamHandler()
        stdout_handler.setFormatter(formatter)
        stdout_handler.setLevel(logging.INFO)
        logger.addHandler(stdout_handler)
    else:  # null handlers for other GPUs
        null_handler = logging.NullHandler()
        null_handler.setLevel(logging.INFO)
        logger.addHandler(null_handler)

    return logger
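A minimal usage sketch (hypothetical log path; local_rank=-1 is the single-GPU default used by the training script):

# Assumed usage: single-GPU run, hypothetical log file name.
import os
os.makedirs('logs', exist_ok=True)
logger = get_logger('logs/example.log', local_rank=-1)
logger.info('messages go to both stdout and logs/example.log')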
@@ -0,0 +1,425 @@
import numpy as np
import pickle
import copy
import sys
import argparse
import logging
import os
from logger import get_logger
from tqdm import tqdm
from collections import deque

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter

from mlc import step_hmlc_K
from mlc_utils import clone_parameters, tocuda, DummyScheduler

from models import *
from meta_models import *

parser = argparse.ArgumentParser(description='MLC Training Framework')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'clothing1m'], default='cifar10')
parser.add_argument('--method', default='hmlc_K_mix', type=str, choices=['hmlc_K_mix', 'hmlc_K'])
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--data_seed', type=int, default=1)
parser.add_argument('--epochs', '-e', type=int, default=75, help='Number of epochs to train.')
parser.add_argument('--num_iterations', default=100000, type=int)
parser.add_argument('--every', default=100, type=int, help='Logging interval (default: every 100 iters)')
parser.add_argument('--bs', default=32, type=int, help='Batch size')
parser.add_argument('--test_bs', default=100, type=int, help='Test batch size')
parser.add_argument('--gold_bs', type=int, default=32)
parser.add_argument('--cls_dim', type=int, default=64, help='Label embedding dim (default: 64)')
parser.add_argument('--grad_clip', default=0.0, type=float, help='Max grad norm (default: 0, no clip)')
parser.add_argument('--momentum', default=0.9, type=float, help='Momentum for the optimizer')
parser.add_argument('--main_lr', default=3e-4, type=float, help='LR for the main net')
parser.add_argument('--meta_lr', default=3e-5, type=float, help='LR for the meta net')
parser.add_argument('--optimizer', default='adam', type=str, choices=['adam', 'sgd', 'adadelta'])
parser.add_argument('--opt_eps', default=1e-8, type=float, help='eps for the optimizers')
#parser.add_argument('--tau', default=1, type=float, help='tau')
parser.add_argument('--wdecay', default=5e-4, type=float, help='Weight decay (default: 5e-4)')

# noise parameters
parser.add_argument('--corruption_type', default='unif', type=str, choices=['unif', 'flip'])
parser.add_argument('--corruption_level', default='-1', type=float, help='Corruption level')
parser.add_argument('--gold_fraction', default='-1', type=float, help='Gold fraction')

parser.add_argument('--sparsemax', default=False, action='store_true', help='Use sparsemax instead of softmax for the meta model (default: False)')
parser.add_argument('--tie', default=False, action='store_true', help='Tie the label embedding to the output classifier weight of the meta net (default: False)')

parser.add_argument('--runid', default='exp', type=str)
parser.add_argument('--queue_size', default=1, type=int, help='Number of recent validation metrics to average (default: 1)')

############## LOOK-AHEAD GRADIENT STEPS FOR MLC ##################
parser.add_argument('--gradient_steps', default=1, type=int, help='Number of look-ahead gradient steps for the meta-gradient (default: 1)')

# CIFAR
# Positional arguments
parser.add_argument('--data_path', default='data', type=str, help='Root for the datasets.')
# Optimization options
parser.add_argument('--nosgdr', default=False, action='store_true', help='Turn off SGDR.')

# Acceleration
parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.')
# i/o
parser.add_argument('--logdir', type=str, default='runs', help='Log folder.')
parser.add_argument('--local_rank', type=int, default=-1, help='Local rank (-1 for local training)')

args = parser.parse_args()

# //////////////////////////// set logging ///////////////////////
filename = '_'.join([args.dataset, args.method, args.corruption_type, args.runid, str(args.epochs), str(args.seed), str(args.data_seed)])
logfile = 'logs/' + filename + '.log'
logger = get_logger(logfile, args.local_rank)
# //////////////////////////////////////////////////////////////////

logger.info(args)
logger.info('CUDA available: ' + str(torch.cuda.is_available()))

# cuda set up
torch.cuda.set_device(0)  # local GPU

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# loss functions for hard labels and soft labels
hard_loss_f = F.cross_entropy
from mlc_utils import soft_cross_entropy as soft_loss_f

# //////////////////////// defining model ////////////////////////

def get_data(dataset, gold_fraction, corruption_prob, get_C):
    if dataset == 'cifar10' or dataset == 'cifar100':
        sys.path.append('CIFAR')

        from data_helper_cifar import prepare_data
        args.use_mwnet_loader = True  # use exactly the same loader as in the MW-Net paper
        logger.info('================= Use the same dataloader as in MW-Net =========================')
        return prepare_data(gold_fraction, corruption_prob, get_C, args)
    elif dataset == 'clothing1m':
        sys.path.append('CLOTHING1M')

        from data_helper_clothing1m import prepare_data
        return prepare_data(args)


def build_models(dataset, num_classes):
    cls_dim = args.cls_dim  # input label embedding dimension

    if dataset in ['cifar10', 'cifar100']:
        from CIFAR.resnet import resnet32

        # main net
        model = resnet32(num_classes)
        main_net = model

        # meta net
        hx_dim = 64  # 0 if isinstance(model, WideResNet) else 64; 64 for resnet-32
        meta_net = MetaNet(hx_dim, cls_dim, 128, num_classes, args)

    elif dataset == 'clothing1m':  # use a pretrained ResNet-50 model
        model = ResNet50(num_classes)
        main_net = model

        hx_dim = 2048  # feature dim of resnet-50
        meta_net = MetaNet(2048, cls_dim, 128, num_classes, args)

    main_net = main_net.cuda()
    meta_net = meta_net.cuda()

    logger.info('========== Main model ==========')
    logger.info(model)
    logger.info('========== Meta model ==========')
    logger.info(meta_net)

    return main_net, meta_net


def setup_training(main_net, meta_net, exp_id=None):
    # ============== setting up from scratch ===================
    # set up optimizers and schedulers
    # meta net optimizer
    optimizer = torch.optim.Adam(meta_net.parameters(), lr=args.meta_lr,
                                 weight_decay=0,  # args.wdecay? unclear whether the meta net should get weight decay
                                 amsgrad=True, eps=args.opt_eps)
    scheduler = DummyScheduler(optimizer)

    # main net optimizer
    main_params = main_net.parameters()

    if args.optimizer == 'adam':
        main_opt = torch.optim.Adam(main_params, lr=args.main_lr, weight_decay=args.wdecay, amsgrad=True, eps=args.opt_eps)
    elif args.optimizer == 'sgd':
        main_opt = torch.optim.SGD(main_params, lr=args.main_lr, weight_decay=args.wdecay, momentum=args.momentum)

    if args.dataset in ['cifar10', 'cifar100']:
        # follow the MW-Net setting
        main_schdlr = torch.optim.lr_scheduler.MultiStepLR(main_opt, milestones=[80, 100], gamma=0.1)
    elif args.dataset in ['clothing1m']:
        main_schdlr = torch.optim.lr_scheduler.MultiStepLR(main_opt, milestones=[5], gamma=0.1)
    else:
        main_schdlr = DummyScheduler(main_opt)

    last_epoch = -1

    return main_net, meta_net, main_opt, optimizer, main_schdlr, scheduler, last_epoch


def uniform_mix_C(num_classes, mixing_ratio):
    '''
    returns a linear interpolation of a uniform matrix and an identity matrix
    '''
    return mixing_ratio * np.full((num_classes, num_classes), 1 / num_classes) + \
        (1 - mixing_ratio) * np.eye(num_classes)


def flip_labels_C(num_classes, corruption_prob):
    '''
    returns a matrix with (1 - corruption_prob) on the diagonal, and corruption_prob
    concentrated in only one other entry for each row
    '''
    np.random.seed(args.seed)

    C = np.eye(num_classes) * (1 - corruption_prob)
    row_indices = np.arange(num_classes)
    for i in range(num_classes):
        C[i][np.random.choice(row_indices[row_indices != i])] = corruption_prob
    return C

||||
def run():
|
||||
corruption_fnctn = uniform_mix_C if args.corruption_type == 'unif' else flip_labels_C
|
||||
filename = '_'.join([args.dataset, args.method, args.corruption_type, args.runid, str(args.epochs), str(args.seed), str(args.data_seed)])
|
||||
|
||||
results = {}
|
||||
|
||||
# 100 labels per class
|
||||
gf_dict = {'yelp2': 200.0 / 560000,
|
||||
'yelp5': 500.0 / 650000,
|
||||
'amazon2': 200.0 / 3600000,
|
||||
'amazon5': 500.0 / 3000000,
|
||||
'dbpedia': 1400.0 / 560000,
|
||||
'yahoo': 700.0 / 1400000,
|
||||
'imdb2': 200.0 / 25000,
|
||||
'ag': 400.0 / 120000,
|
||||
}
|
||||
|
||||
# revisit this
|
||||
gold_fractions = [0.05, 0.001, 0.01]
|
||||
|
||||
if args.gold_fraction != -1:
|
||||
assert args.gold_fraction >=0 and args.gold_fraction <=1, 'Wrong gold fraction!'
|
||||
gold_fractions = [args.gold_fraction]
|
||||
|
||||
corruption_levels = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
||||
|
||||
if args.corruption_level != -1: # specied one corruption_level
|
||||
assert args.corruption_level >= 0 and args.corruption_level <=1, 'Wrong noise level!'
|
||||
corruption_levels = [args.corruption_level]
|
||||
|
||||
for gold_fraction in gold_fractions:
|
||||
results[gold_fraction] = {}
|
||||
for corruption_level in corruption_levels:
|
||||
# //////////////////////// load data //////////////////////////////
|
||||
# use data_seed her
|
||||
gold_loader, silver_loader, valid_loader, test_loader, num_classes = get_data(args.dataset, gold_fraction, corruption_level, corruption_fnctn)
|
||||
|
||||
# //////////////////////// build main_net and meta_net/////////////
|
||||
main_net, meta_net = build_models(args.dataset, num_classes)
|
||||
|
||||
# //////////////////////// train and eval model ///////////////////
|
||||
exp_id = '_'.join([filename, str(gold_fraction), str(corruption_level)])
|
||||
test_acc, baseline_acc = train_and_test(main_net, meta_net, gold_loader, silver_loader, valid_loader, test_loader, exp_id)
|
||||
|
||||
results[gold_fraction][corruption_level] = {}
|
||||
results[gold_fraction][corruption_level]['method'] = test_acc
|
||||
results[gold_fraction][corruption_level]['baseline'] = baseline_acc
|
||||
logger.info(' '.join(['Gold fraction:', str(gold_fraction), '| Corruption level:', str(corruption_level),
|
||||
'| Method acc:', str(results[gold_fraction][corruption_level]['method']),
|
||||
'| Baseline acc:', str(results[gold_fraction][corruption_level]['baseline'])]))
|
||||
logger.info('')
|
||||
|
||||
|
||||
with open('out/' + filename, 'wb') as file:
|
||||
pickle.dump(results, file)
|
||||
logger.info("Dumped results_ours in file: " + filename)
|
||||
|
||||
def test(main_net, test_loader): # this could be eval or test
|
||||
# //////////////////////// evaluate method ////////////////////////
|
||||
correct = torch.zeros(1).cuda()
|
||||
nsamples = torch.zeros(1).cuda()
|
||||
|
||||
# forward
|
||||
main_net.eval()
|
||||
|
||||
for idx, (*data, target) in enumerate(test_loader):
|
||||
data, target = tocuda(data), tocuda(target)
|
||||
|
||||
# forward
|
||||
with torch.no_grad():
|
||||
output = main_net(data)
|
||||
|
||||
# accuracy
|
||||
pred = output.data.max(1)[1]
|
||||
correct += pred.eq(target.data).sum().item()
|
||||
nsamples += len(target)
|
||||
|
||||
test_acc = (correct / nsamples).item()
|
||||
|
||||
# set back to train
|
||||
main_net.train()
|
||||
|
||||
return test_acc
|
||||
|
||||
|
||||
####################################################################################################
|
||||
### training code
|
||||
####################################################################################################
|
||||
def train_and_test(main_net, meta_net, gold_loader, silver_loader, valid_loader, test_loader, exp_id=None):
|
||||
writer = SummaryWriter(args.logdir + '/' + exp_id)
|
||||
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
|
||||
main_net, meta_net, main_opt, optimizer, main_schdlr, scheduler, last_epoch = setup_training(main_net, meta_net, exp_id)
|
||||
|
||||
# //////////////////////// switching on training mode ////////////////////////
|
||||
meta_net.train()
|
||||
main_net.train()
|
||||
|
||||
# set up statistics
|
||||
best_params = None
|
||||
best_main_opt_sd = None
|
||||
best_main_schdlr_sd = None
|
||||
|
||||
best_meta_params = None
|
||||
best_meta_opt_sd = None
|
||||
best_meta_schdlr_sd = None
|
||||
best_val_metric = float('inf')
|
||||
|
||||
val_metric_queue = deque()
|
||||
# set done
|
||||
|
||||
args.dw_prev = [0 for param in meta_net.parameters()] # 0 for previous iteration
|
||||
args.steps = 0
|
||||
|
||||
for epoch in tqdm(range(last_epoch+1, args.epochs)):# change to epoch iteration
|
||||
logger.info('Epoch %d:' % epoch)
|
||||
|
||||
for i, (*data_s, target_s) in enumerate(silver_loader):
|
||||
*data_g, target_g = next(gold_loader)#.next()
|
||||
|
||||
data_g, target_g = tocuda(data_g), tocuda(target_g)
|
||||
data_s, target_s_ = tocuda(data_s), tocuda(target_s)
|
||||
|
||||
# bi-level optimization stage
|
||||
eta = main_schdlr.get_lr()[0]
|
||||
if args.method == 'hmlc_K':
|
||||
loss_g, loss_s = step_hmlc_K(main_net, main_opt, hard_loss_f,
|
||||
meta_net, optimizer, soft_loss_f,
|
||||
data_s, target_s_, data_g, target_g,
|
||||
None, None,
|
||||
eta, args)
|
||||
elif args.method == 'hmlc_K_mix':
|
||||
# split the clean set to two, one for training and the other for meta-evaluation
|
||||
gbs = int(target_g.size(0) / 2)
|
||||
if type(data_g) is list:
|
||||
data_c = [x[gbs:] for x in data_g]
|
||||
data_g = [x[:gbs] for x in data_g]
|
||||
else:
|
||||
data_c = data_g[gbs:]
|
||||
data_g = data_g[:gbs]
|
||||
|
||||
target_c = target_g[gbs:]
|
||||
target_g = target_g[:gbs]
|
||||
loss_g, loss_s = step_hmlc_K(main_net, main_opt, hard_loss_f,
|
||||
meta_net, optimizer, soft_loss_f,
|
||||
data_s, target_s_, data_g, target_g,
|
||||
data_c, target_c,
|
||||
eta, args)
|
||||
|
||||
args.steps += 1
|
||||
if i % args.every == 0:
|
||||
writer.add_scalar('train/loss_g', loss_g.item(), args.steps)
|
||||
writer.add_scalar('train/loss_s', loss_s.item(), args.steps)
|
||||
|
||||
''' get entropy of predictions from meta-net '''
|
||||
logit_s, x_s_h = main_net(data_s, return_h=True)
|
||||
pseudo_target_s = meta_net(x_s_h.detach(), target_s_).detach()
|
||||
entropy = -(pseudo_target_s * torch.log(pseudo_target_s+1e-10)).sum(-1).mean()
|
||||
|
||||
writer.add_scalar('train/meta_entropy', entropy.item(), args.steps)
|
||||
|
||||
main_lr = main_schdlr.get_lr()[0]
|
||||
meta_lr = scheduler.get_lr()[0]
|
||||
writer.add_scalar('train/main_lr', main_lr, args.steps)
|
||||
writer.add_scalar('train/meta_lr', meta_lr, args.steps)
|
||||
writer.add_scalar('train/gradient_steps', args.gradient_steps, args.steps)
|
||||
|
||||
logger.info('Iteration %d loss_s: %.4f\tloss_g: %.4f\tMeta entropy: %.3f\tMain LR: %.8f\tMeta LR: %.8f' %( i, loss_s.item(), loss_g.item(), entropy.item(), main_lr, meta_lr))
|
||||
|
||||
# PER EPOCH PROCESSING
|
||||
|
||||
# lr scheduler
|
||||
main_schdlr.step()
|
||||
#scheduler.step()
|
||||
|
||||
# evaluation on validation set
|
||||
val_acc = test(main_net, valid_loader)
|
||||
test_acc = test(main_net, test_loader)
|
||||
|
||||
logger.info('Val acc: %.4f\tTest acc: %.4f' % (val_acc, test_acc))
|
||||
if args.local_rank <=0: # single GPU or GPU 0
|
||||
writer.add_scalar('train/val_acc', val_acc, epoch)
|
||||
writer.add_scalar('test/test_acc', test_acc, epoch)
|
||||
|
||||
if len(val_metric_queue) == args.queue_size: # keep at most this number of records
|
||||
# remove the oldest record
|
||||
val_metric_queue.popleft()
|
||||
|
||||
val_metric_queue.append(-val_acc)
|
||||
|
||||
avg_val_metric = sum(list(val_metric_queue)) / len(val_metric_queue)
|
||||
if avg_val_metric < best_val_metric:
|
||||
best_val_metric = avg_val_metric
|
||||
|
||||
best_params = copy.deepcopy(main_net.state_dict())
|
||||
|
||||
best_main_opt_sd = copy.deepcopy(main_opt.state_dict())
|
||||
best_main_schdlr_sd = copy.deepcopy(main_schdlr.state_dict())
|
||||
|
||||
best_meta_params = copy.deepcopy(meta_net.state_dict())
|
||||
best_meta_opt_sd = copy.deepcopy(optimizer.state_dict())
|
||||
best_meta_schdlr_sd = copy.deepcopy(scheduler.state_dict())
|
||||
|
||||
# dump best to file also
|
||||
####################### save best models so far ###################
|
||||
|
||||
logger.info('Saving best models...')
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'val_metric': best_val_metric,
|
||||
'main_net': best_params,
|
||||
'main_opt': best_main_opt_sd,
|
||||
'main_schdlr': best_main_schdlr_sd,
|
||||
'meta_net': best_meta_params,
|
||||
'meta_opt': best_meta_opt_sd,
|
||||
'meta_schdlr': best_meta_schdlr_sd
|
||||
}, 'models/%s_best.pth' % exp_id)
|
||||
|
||||
|
||||
writer.add_scalar('train/val_acc_best', -best_val_metric, epoch) # write current best val_acc to tensorboard
|
||||
|
||||
# //////////////////////// evaluating method ////////////////////////
|
||||
main_net.load_state_dict(best_params)
|
||||
test_acc = test(main_net, test_loader) # evaluate best params picked from validation
|
||||
|
||||
writer.add_scalar('test/acc', test_acc, args.steps) # this test_acc should be roughly the best as it's taken from the best iteration
|
||||
logger.info('Test acc: %.4f' % test_acc)
|
||||
|
||||
return test_acc, 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
|
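For reference, a plausible single-GPU invocation using only the flags defined above (the script name is not part of this diff and is assumed to be main.py; flag values are illustrative). Note that mlc.update_params below only implements SGD, so the main optimizer must be sgd, and the script expects the logs/, out/, and models/ directories to exist.

python main.py --dataset cifar10 --optimizer sgd --bs 100 --corruption_type unif \
    --corruption_level 0.4 --gold_fraction 0.02 --epochs 120 --main_lr 0.1 \
    --meta_lr 3e-5 --method hmlc_K_mix --gradient_steps 1 --runid demo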
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MetaNet(nn.Module):
    def __init__(self, hx_dim, cls_dim, h_dim, num_classes, args):
        super().__init__()

        self.args = args

        self.num_classes = num_classes
        self.in_class = self.num_classes
        self.hdim = h_dim
        self.cls_emb = nn.Embedding(self.in_class, cls_dim)

        in_dim = hx_dim + cls_dim

        self.net = nn.Sequential(
            nn.Linear(in_dim, self.hdim),
            nn.Tanh(),
            nn.Linear(self.hdim, self.hdim),
            nn.Tanh(),
            nn.Linear(self.hdim, num_classes, bias=(not self.args.tie))
        )

        if self.args.sparsemax:
            from sparsemax import Sparsemax
            self.sparsemax = Sparsemax(-1)

        self.init_weights()

        if self.args.tie:
            print('Tying cls emb to output cls weight')
            self.net[-1].weight = self.cls_emb.weight

    def init_weights(self):
        nn.init.xavier_uniform_(self.cls_emb.weight)
        nn.init.xavier_normal_(self.net[0].weight)
        nn.init.xavier_normal_(self.net[2].weight)
        nn.init.xavier_normal_(self.net[4].weight)

        self.net[0].bias.data.zero_()
        self.net[2].bias.data.zero_()

        if not self.args.tie:
            assert self.in_class == self.num_classes, 'In and out classes conflict!'
            self.net[4].bias.data.zero_()

    def get_alpha(self):
        return torch.zeros(1)

    def forward(self, hx, y):
        bs = hx.size(0)

        y_emb = self.cls_emb(y)
        hin = torch.cat([hx, y_emb], dim=-1)

        logit = self.net(hin)

        if self.args.sparsemax:
            out = self.sparsemax(logit)  # test sparsemax
        else:
            out = F.softmax(logit, -1)

        return out

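A hedged usage sketch of the meta-net interface assumed by the training loop: it consumes the main net's 64-dimensional feature and the (possibly noisy) integer label, and emits a corrected soft-label distribution.

# Assumed usage sketch; argument values mirror the CIFAR defaults in the training script.
import argparse
import torch

args = argparse.Namespace(sparsemax=False, tie=False)
meta_net = MetaNet(hx_dim=64, cls_dim=64, h_dim=128, num_classes=10, args=args)

hx = torch.randn(8, 64)                # features from resnet32(..., return_h=True)
y_noisy = torch.randint(0, 10, (8,))   # observed (possibly corrupted) labels
pseudo_target = meta_net(hx, y_noisy)  # (8, 10) soft labels, rows sum to 1
print(pseudo_target.sum(-1))           # ~torch.ones(8)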
@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

def _concat(xs):
    return torch.cat([x.view(-1) for x in xs])

@torch.no_grad()
def update_params(params, grads, eta, opt, args, deltaonly=False, return_s=False):
    if isinstance(opt, torch.optim.SGD):
        return update_params_sgd(params, grads, eta, opt, args, deltaonly, return_s)
    else:
        raise NotImplementedError('Non-supported main model optimizer type!')


# be aware that the opt state dict returns references, hence take care not to
# modify them
def update_params_sgd(params, grads, eta, opt, args, deltaonly, return_s=False):
    # supports SGD-like optimizers
    ans = []

    if return_s:
        ss = []

    wdecay = opt.defaults['weight_decay']
    momentum = opt.defaults['momentum']
    dampening = opt.defaults['dampening']
    nesterov = opt.defaults['nesterov']

    for i, param in enumerate(params):
        dparam = grads[i] + param * wdecay  # s = 1
        s = 1

        if momentum > 0:
            try:
                moment = opt.state[param]['momentum_buffer'] * momentum
            except KeyError:  # no momentum buffer yet
                moment = torch.zeros_like(param)

            moment.add_(dparam, alpha=1. - dampening)  # s = 1 - dampening

            if nesterov:
                dparam = dparam + momentum * moment  # s = 1 + momentum * (1 - dampening)
                s = 1 + momentum * (1. - dampening)
            else:
                dparam = moment  # s = 1 - dampening
                s = 1. - dampening

        if deltaonly:
            ans.append(- dparam * eta)
        else:
            ans.append(param - dparam * eta)

        if return_s:
            ss.append(s * eta)

    if return_s:
        return ans, ss
    else:
        return ans


# ============== MLC step procedure with features (gradient-stopped) from the main model ===========
#
# The meta net uses the last K-1 steps from the main model and imagines one additional step ahead
# to compose a pool of K actual steps from the main model.
#
def step_hmlc_K(main_net, main_opt, hard_loss_f,
                meta_net, meta_opt, soft_loss_f,
                data_s, target_s, data_g, target_g,
                data_c, target_c,
                eta, args):
    # compute gw for updating meta_net
    logit_g = main_net(data_g)
    loss_g = hard_loss_f(logit_g, target_g)
    gw = torch.autograd.grad(loss_g, main_net.parameters())

    # given the current meta net, get the corrected label
    logit_s, x_s_h = main_net(data_s, return_h=True)
    pseudo_target_s = meta_net(x_s_h.detach(), target_s)
    loss_s = soft_loss_f(logit_s, pseudo_target_s)

    if data_c is not None:
        bs1 = target_s.size(0)
        bs2 = target_c.size(0)

        logit_c = main_net(data_c)
        loss_s2 = hard_loss_f(logit_c, target_c)
        loss_s = (loss_s * bs1 + loss_s2 * bs2) / (bs1 + bs2)

    f_param_grads = torch.autograd.grad(loss_s, main_net.parameters(), create_graph=True)

    # 1. take one virtual SGD step on the corrected (silver) loss
    f_params_new, dparam_s = update_params(main_net.parameters(), f_param_grads, eta, main_opt, args, return_s=True)
    # 2. set w as w'
    f_param = []
    for i, param in enumerate(main_net.parameters()):
        f_param.append(param.data.clone())
        param.data = f_params_new[i].data  # use data only as f_params_new has graph

    # training loss Hessian approximation
    Hw = 1  # assumed to be the identity

    # 3. compute d_w' L_{D}(w')
    logit_g = main_net(data_g)
    loss_g = hard_loss_f(logit_g, target_g)
    gw_prime = torch.autograd.grad(loss_g, main_net.parameters())

    # 3.5 compute the discount factor gw_prime * (I - LH) * gw.t() / |gw|^2
    tmp1 = [(1 - Hw * dparam_s[i]) * gw_prime[i] for i in range(len(dparam_s))]
    gw_norm2 = (_concat(gw).norm())**2
    tmp2 = [gw[i] / gw_norm2 for i in range(len(gw))]
    gamma = torch.dot(_concat(tmp1), _concat(tmp2))

    # because of dparam_s, need to scale up/down f_params_grads_prime for proxy_g/loss_g
    Lgw_prime = [dparam_s[i] * gw_prime[i] for i in range(len(dparam_s))]

    proxy_g = -torch.dot(_concat(f_param_grads), _concat(Lgw_prime))

    # backprop on alphas (the meta-net parameters)
    meta_opt.zero_grad()
    proxy_g.backward()

    # accumulate the discounted iterative gradient
    for i, param in enumerate(meta_net.parameters()):
        if param.grad is not None:
            param.grad.add_(gamma * args.dw_prev[i])
            args.dw_prev[i] = param.grad.clone()

    if (args.steps + 1) % args.gradient_steps == 0:  # the main net has proceeded T steps
        meta_opt.step()
        args.dw_prev = [0 for param in meta_net.parameters()]  # reset to 0

    # restore w, then perform the actual update of main_net
    for i, param in enumerate(main_net.parameters()):
        param.data = f_param[i]
        param.grad = f_param_grads[i].data
    main_opt.step()

    return loss_g, loss_s

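For orientation, step_hmlc_K appears to implement the usual one-step look-ahead meta-gradient; a sketch under the identity-Hessian assumption the code itself makes (notation and correspondence are my reading of the code, not an authors' statement):

% one virtual SGD step on the corrected (silver) loss, with meta parameters \alpha:
%   w' = w - \eta \, \nabla_w L_s(w; \alpha)
% meta-gradient of the clean (gold) loss evaluated at w':
%   \frac{d L_g(w')}{d \alpha}
%     = -\eta \, \Big( \tfrac{\partial}{\partial \alpha} \nabla_w L_s(w; \alpha) \Big)^{\top} \nabla_{w'} L_g(w')
% In the code, f_param_grads plays the role of \nabla_w L_s (with graph back to \alpha),
% Lgw_prime is the per-parameter step-size-scaled \nabla_{w'} L_g, so proxy_g realizes the
% dot product above, and gamma = g_{w'}^{\top}(I - \eta H) g_w / \|g_w\|^2 discounts the
% gradient carried over from earlier steps (args.dw_prev).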
@@ -0,0 +1,71 @@
import torch
import torch.nn.functional as F

def save_checkpoint(state, filename):
    torch.save(state, filename)

class DummyScheduler(torch.optim.lr_scheduler._LRScheduler):
    def get_lr(self):
        lrs = []
        for param_group in self.optimizer.param_groups:
            lrs.append(param_group['lr'])

        return lrs

    def step(self, epoch=None):
        pass

def tocuda(data):
    if type(data) is list:
        if len(data) == 1:
            return data[0].cuda()
        else:
            return [x.cuda() for x in data]
    else:
        return data.cuda()

'''
def net_grad_norm_max(model, p):
    grad_norms = [x.grad.data.norm(p).item() for x in model.parameters()]
    return max(grad_norms)
'''

def clone_parameters(model):
    assert isinstance(model, torch.nn.Module), 'Wrong model type'

    params = model.named_parameters()

    f_params_dict = {}
    f_params = []
    for idx, (name, param) in enumerate(params):
        new_param = torch.nn.Parameter(param.data.clone())
        f_params_dict[name] = idx
        f_params.append(new_param)

    return f_params, f_params_dict

# target-differentiable version of F.cross_entropy
def soft_cross_entropy(logit, pseudo_target, reduction='mean'):
    loss = -(pseudo_target * F.log_softmax(logit, -1)).sum(-1)
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'none':
        return loss
    elif reduction == 'sum':
        return loss.sum()
    else:
        raise NotImplementedError('Invalid reduction: %s' % reduction)


# test code for soft_cross_entropy
if __name__ == '__main__':
    K = 100
    for _ in range(10000):
        y = torch.randint(K, (100,))
        y_onehot = F.one_hot(y, K).float()
        logits = torch.randn(100, K)

        l1 = F.cross_entropy(logits, y)
        l2 = soft_cross_entropy(logits, y_onehot)

        print(l1.item(), l2.item(), '%.5f' % (l1 - l2).abs().item())

@@ -0,0 +1,28 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResNet50(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        import torchvision
        import os

        os.environ['TORCH_HOME'] = 'cache'  # hacky workaround to set the model cache dir
        self.resnet50 = torchvision.models.resnet50(pretrained=True)
        self.resnet50.fc = nn.Identity()  # remove the last fc layer
        self.fc = nn.Linear(2048, num_classes)

        self.init_weights()

    def init_weights(self):
        nn.init.xavier_normal_(self.fc.weight)
        self.fc.bias.data.zero_()

    def forward(self, x, return_h=False):  # (bs, C, H, W)
        pooled_output = self.resnet50(x)
        logit = self.fc(pooled_output)
        if return_h:
            return logit, pooled_output
        else:
            return logit