Adding DetCon-B loss and partial implementation

This commit is contained in:
Anthony Ortiz 2021-10-21 06:12:32 +00:00
Parent 059fa85b45
Commit 8251272044
5 changed files: 342 additions and 0 deletions

14 conf/detcon.yaml Normal file
View file

@@ -0,0 +1,14 @@
experiment:
  task: "ssl"
  name: "test_detcon"
  module:
    model: "detcon"
    encoder: "resnet18"
    input_channels: 3
    training_mode: "self-supervised"
    imagenet_pretraining: True
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
  datamodule:
    batch_size: 64
    num_workers: 6
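
For reference, configs in this style are typically loaded with OmegaConf before being handed to the training entry point; a minimal sketch, assuming the file sits at conf/detcon.yaml (the loading pattern is an assumption, not part of this commit):

from omegaconf import OmegaConf

# Load the experiment config; the keys mirror the YAML above.
conf = OmegaConf.load("conf/detcon.yaml")
module_args = OmegaConf.to_container(conf.experiment.module)
datamodule_args = OmegaConf.to_container(conf.experiment.datamodule)
print(module_args["encoder"], datamodule_args["batch_size"])  # resnet18 64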

View file

@@ -0,0 +1,14 @@
experiment:
  task: "ssl"
  name: "test_detcon"
  module:
    model: "detcon"
    encoder: "resnet18"
    input_channels: 3
    training_mode: "self-supervised"
    imagenet_pretraining: True
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
  datamodule:
    batch_size: 64
    num_workers: 6

View file

View file

@@ -0,0 +1,177 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""DetCon-B: DetCon implementation for BYOL."""

from copy import deepcopy
from typing import Any, Mapping, Optional, Text, Tuple, Union, cast

import numpy as np
import skimage.segmentation
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Module

from torchgeo.trainers import byol

Module.__module__ = "torch.nn"

def featurewise_std(x: np.ndarray) -> np.ndarray:
    """Computes the featurewise standard deviation."""
    return np.mean(np.std(x, axis=0))


def compute_fh_segmentation(image_np: np.ndarray, scale: float, min_size: int) -> np.ndarray:
    """Computes Felzenszwalb-Huttenlocher (FH) segmentation of an image.

    Args:
        image_np: image as an (H, W, C) array
        scale: higher values produce larger FH segments
        min_size: minimum FH component size

    Returns:
        the segment-id map as an (H, W) uint8 array
    """
    segmented_image = skimage.segmentation.felzenszwalb(
        image_np, scale=scale, min_size=min_size
    )
    return segmented_image.astype(np.dtype("<u1"))
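
# Example (illustrative values, not part of this commit): for an (H, W, 3)
# uint8 image array `img`,
#     mask = compute_fh_segmentation(img, scale=1000.0, min_size=10)
# yields an (H, W) uint8 segment-id map whose ids index the binary masks
# built by DetConB.create_binary_mask below.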


class DetConB(Module):
    """DetCon-B's training component definition."""

    def config_task(self) -> None:
        """Configures the task based on kwargs parameters passed to the constructor."""
        assert self.hparams["training_mode"] in ["self-supervised", "supervised", "both"]
        self.training_mode = self.hparams["training_mode"]

    def __init__(
        self,
        mode: Text,
        model: Module,
        image_size: Tuple[int, int] = (224, 224),
        hidden_layer: Union[str, int] = -2,
        input_channels: int = 3,
        projection_size: int = 256,
        hidden_size: int = 4096,
        augment_fn: Optional[Module] = None,
        beta: float = 0.99,
        **kwargs: Any,
    ) -> None:
        """Constructs the experiment.

        Args:
            mode: a string, equivalent to FLAGS.mode when running normally
            model: the backbone model to wrap with the BYOL encoder
            image_size: expected (height, width) of the input images
            hidden_layer: layer of ``model`` from which to take features
            input_channels: number of channels of the input images
            projection_size: output size of the projection head
            hidden_size: hidden width of the projection/prediction MLPs
            augment_fn: augmentation module; defaults to SimCLR augmentations
            beta: EMA coefficient for target-network updates
            **kwargs: remaining experiment configuration, stored as hparams
        """
        super().__init__()  # nn.Module's constructor takes no arguments
        self.mode = mode
        self.hparams = kwargs  # plain dict; read by config_task()
        self.augment: Module
        if augment_fn is None:
            self.augment = byol.SimCLRAugmentation(image_size)
        else:
            self.augment = augment_fn
        self.beta = beta
        self.input_channels = input_channels
        self.encoder = byol.EncoderWrapper(
            model, projection_size, hidden_size, layer=hidden_layer
        )
        self.predictor = byol.MLP(projection_size, projection_size, hidden_size)
        self._target: Optional[Module] = None

        # Perform a single forward pass to initialize the wrapper correctly
        self.encoder(torch.zeros(2, self.input_channels, *image_size))
        self.config_task()

    def create_binary_mask(
        self,
        batch_size: int,
        num_pixels: int,
        masks: np.ndarray,
        max_mask_id: int = 256,
        downsample: int = 32,
    ) -> np.ndarray:
        """Generates binary masks from the Felzenszwalb masks.

        From an FH mask of shape [batch_size, H, W, num_fh_masks] (values in
        range [0, max_mask_id]), produces corresponding (downsampled) binary
        masks of shape [batch_size, max_mask_id, num_pixels].

        Args:
            batch_size: batch size of the masks
            num_pixels: number of points on the spatial grid after pooling
            masks: Felzenszwalb masks
            max_mask_id: number of unique masks in the Felzenszwalb segmentation
            downsample: spatial rate at which masks are downsampled

        Returns:
            binary_mask: binary mask with the specification above
        """
        fh_mask_to_use = self.hparams["fh_mask_to_use"]
        mask = masks[..., fh_mask_to_use : fh_mask_to_use + 1]
        mask_ids = np.arange(max_mask_id).reshape(1, 1, 1, max_mask_id)
        binary_mask = np.equal(mask_ids, mask).astype("float32")  # [B, H, W, M]
        # torch's avg_pool2d expects [B, C, H, W], so pool each mask channel
        # and convert back to a NumPy array afterwards.
        pooled = F.avg_pool2d(
            torch.from_numpy(binary_mask).permute(0, 3, 1, 2),
            downsample,
            downsample,
            count_include_pad=False,
        )
        binary_mask = pooled.permute(0, 2, 3, 1).numpy()
        binary_mask = binary_mask.reshape(batch_size, num_pixels, max_mask_id)
        # Assign each pooled cell to its dominant mask id, then one-hot encode.
        binary_mask = np.argmax(binary_mask, axis=-1)
        binary_mask = np.eye(max_mask_id)[binary_mask]
        binary_mask = np.transpose(binary_mask, [0, 2, 1])
        return binary_mask
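
    # Illustrative shapes (assumption): with 224x224 inputs and downsample=32,
    # num_pixels = (224 // 32) ** 2 = 49, so create_binary_mask returns a
    # [batch_size, 256, 49] array of one-hot masks per pooled grid cell.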

    def sample_masks(
        self,
        binary_mask: np.ndarray,
        batch_size: int,
        n_random_vectors: int = 16,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Samples which binary masks to use in the loss."""
        mask_exists = np.greater(binary_mask.sum(-1), 1e-3)
        sel_masks = mask_exists.astype("float32") + 1e-11
        # Normalize into a per-image distribution over mask ids; empty masks
        # keep a vanishing (but nonzero) probability of being drawn.
        sel_masks = sel_masks / sel_masks.sum(1, keepdims=True)
        # np.random.choice samples from probabilities directly, so no
        # log-space conversion is needed here.
        mask_ids = np.stack(
            [
                np.random.choice(
                    sel_masks.shape[1], size=n_random_vectors, p=sel_masks[b]
                )
                for b in range(batch_size)
            ]
        )
        smpl_masks = np.stack(
            [binary_mask[b][mask_ids[b]] for b in range(batch_size)]
        )
        return smpl_masks, mask_ids
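
    # Example (illustrative): with binary_mask of shape [B, 256, P], the call
    #     smpl_masks, mask_ids = self.sample_masks(binary_mask, B)
    # returns smpl_masks of shape [B, 16, P] and mask_ids of shape [B, 16].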

    @property
    def target(self) -> Module:
        """The "target" model."""
        if self._target is None:
            self._target = deepcopy(self.encoder)
        return self._target

    def update_target(self) -> None:
        """Method to update the "target" model weights."""
        for p, pt in zip(self.encoder.parameters(), self.target.parameters()):
            pt.data = self.beta * pt.data + (1 - self.beta) * p.data

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the encoder model through the MLP and prediction head.

        Args:
            x: tensor of data to run through the model

        Returns:
            output from the model
        """
        return cast(Tensor, self.predictor(self.encoder(x)))

    def run_detcon_b_forward_on_view(
        self,
        view_encoder: Any,
        projector: Any,
        predictor: Any,
        classifier: Any,
        is_training: bool,
        images: np.ndarray,
        masks: np.ndarray,
        suffix: Text = "",
    ) -> None:
        """Runs one DetCon-B forward pass on a single augmented view.

        Not yet implemented in this partial port.
        """
        pass

    def _forward(
        self,
        inputs: Mapping[Text, np.ndarray],  # upstream annotates this as image_dataset.Batch
        is_training: bool,
    ) -> Mapping[Text, np.ndarray]:
        """Forward application of byol's architecture.

        Args:
            inputs: a batch of data, i.e. a dictionary, with either two keys
                (`images` and `labels`) or three keys (`view1`, `view2`,
                `labels`).
            is_training: training or evaluating the model? When True, inputs
                must contain keys `view1` and `view2`. When False, inputs must
                contain key `images`.

        Returns:
            All outputs of the model, i.e. a dictionary with projection,
            prediction and logits keys, for either the two views or the image.
        """
        pass
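
For context, a minimal sketch of how this new module might be driven, assuming a torchvision backbone and that training_mode and fh_mask_to_use arrive through **kwargs (the import path and call pattern are assumptions, not part of this commit):

import torch
from torchvision.models import resnet18

from detcon import DetConB  # hypothetical import path for the file above

backbone = resnet18(pretrained=True)
module = DetConB(
    mode="train",
    model=backbone,
    image_size=(224, 224),
    input_channels=3,
    training_mode="self-supervised",  # consumed by config_task()
    fh_mask_to_use=0,                 # consumed by create_binary_mask()
)

x = torch.rand(2, 3, 224, 224)
pred = module(x)  # encoder projection -> prediction head
print(pred.shape)  # expected: torch.Size([2, 256]) with the default projection_size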

View file

@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Loss functions for DetCon-B."""

from typing import Optional

import numpy as np


# Replaces torch.nn.functional.log_softmax, which does not accept NumPy arrays.
def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable log-softmax in NumPy."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))


def manual_cross_entropy(labels, logits, weight):
    """Manually computes cross-entropy loss.

    Args:
        labels: tensor labels
        logits: tensor logits
        weight: per-sample weights applied to the loss

    Returns:
        cross-entropy loss
    """
    ce = -weight * np.sum(labels * log_softmax(logits, axis=-1), axis=-1)
    return np.mean(ce)


def l2_normalize(
    x: np.ndarray,
    axis: Optional[int] = None,
    epsilon: float = 1e-12,
) -> np.ndarray:
    """L2-normalizes a tensor along an axis with numerical stability."""
    square_sum = np.sum(np.square(x), axis=axis, keepdims=True)
    x_inv_norm = 1.0 / np.sqrt(np.maximum(square_sum, epsilon))
    return x * x_inv_norm


def byol_nce_detcon(pred1, pred2, target1, target2,
                    pind1, pind2, tind1, tind2,
                    temperature=0.1, use_replicator_loss=True,
                    local_negatives=True):
    """Compute the NCE scores from pairs of predictions and targets.

    This implements the batched form of the loss described in
    Section 3.1, Equation 3 of https://arxiv.org/pdf/2103.10957.pdf.

    Args:
        pred1 (np.ndarray): the prediction from the first view.
        pred2 (np.ndarray): the prediction from the second view.
        target1 (np.ndarray): the projection from the first view.
        target2 (np.ndarray): the projection from the second view.
        pind1 (np.ndarray): mask indices for the first view's prediction.
        pind2 (np.ndarray): mask indices for the second view's prediction.
        tind1 (np.ndarray): mask indices for the first view's projection.
        tind2 (np.ndarray): mask indices for the second view's projection.
        temperature (float): the temperature to use for the NCE loss.
        use_replicator_loss (bool): use cross-replica samples (unused in this
            single-device port).
        local_negatives (bool): whether to include local negatives.

    Returns:
        A single scalar loss for the XT-NCE objective.
    """
    batch_size = pred1.shape[0]
    num_rois = pred1.shape[1]
    infinity_proxy = 1e9  # Used in masks to stand in for a very large number.

    def make_same_obj(ind_0, ind_1):
        # [B, R, 1, R] indicator of "these two ROIs come from the same mask".
        same_obj = np.equal(
            ind_0.reshape([batch_size, num_rois, 1]),
            ind_1.reshape([batch_size, 1, num_rois]),
        )
        return np.expand_dims(same_obj.astype("float32"), axis=2)

    same_obj_aa = make_same_obj(pind1, tind1)
    same_obj_ab = make_same_obj(pind1, tind2)
    same_obj_ba = make_same_obj(pind2, tind1)
    same_obj_bb = make_same_obj(pind2, tind2)

    # L2 normalize the tensors to use for the cosine similarity.
    pred1 = l2_normalize(pred1, axis=-1)
    pred2 = l2_normalize(pred2, axis=-1)
    target1 = l2_normalize(target1, axis=-1)
    target2 = l2_normalize(target2, axis=-1)

    # Only works on a single GPU for now: targets are not gathered across
    # replicas, so the "large" targets are just the local ones.
    target1_large = target1
    target2_large = target2
    labels_local = np.eye(batch_size, dtype="float32")
    labels_ext = np.eye(batch_size, batch_size * 2, dtype="float32")
    labels_local = np.expand_dims(np.expand_dims(labels_local, axis=2), axis=1)
    labels_ext = np.expand_dims(np.expand_dims(labels_ext, axis=2), axis=1)

    # Do our matmuls and mask out appropriately.
    logits_aa = np.einsum("abk,uvk->abuv", pred1, target1_large) / temperature
    logits_bb = np.einsum("abk,uvk->abuv", pred2, target2_large) / temperature
    logits_ab = np.einsum("abk,uvk->abuv", pred1, target2_large) / temperature
    logits_ba = np.einsum("abk,uvk->abuv", pred2, target1_large) / temperature

    labels_aa = labels_local * same_obj_aa
    labels_ab = labels_local * same_obj_ab
    labels_ba = labels_local * same_obj_ba
    labels_bb = labels_local * same_obj_bb

    # Intra-view pairs of the same object are removed from the positives.
    logits_aa = logits_aa - infinity_proxy * labels_local * same_obj_aa
    logits_bb = logits_bb - infinity_proxy * labels_local * same_obj_bb
    labels_aa = 0.0 * labels_aa
    labels_bb = 0.0 * labels_bb

    if not local_negatives:
        logits_aa = logits_aa - infinity_proxy * labels_local * (1 - same_obj_aa)
        logits_ab = logits_ab - infinity_proxy * labels_local * (1 - same_obj_ab)
        logits_ba = logits_ba - infinity_proxy * labels_local * (1 - same_obj_ba)
        logits_bb = logits_bb - infinity_proxy * labels_local * (1 - same_obj_bb)

    labels_abaa = np.concatenate([labels_ab, labels_aa], axis=2)
    labels_babb = np.concatenate([labels_ba, labels_bb], axis=2)
    labels_0 = np.reshape(labels_abaa, [batch_size, num_rois, -1])
    labels_1 = np.reshape(labels_babb, [batch_size, num_rois, -1])

    num_positives_0 = np.sum(labels_0, axis=-1, keepdims=True)
    num_positives_1 = np.sum(labels_1, axis=-1, keepdims=True)
    labels_0 = labels_0 / np.maximum(num_positives_0, 1)
    labels_1 = labels_1 / np.maximum(num_positives_1, 1)

    obj_area_0 = np.sum(make_same_obj(pind1, pind1), axis=(2, 3))
    obj_area_1 = np.sum(make_same_obj(pind2, pind2), axis=(2, 3))
    weights_0 = np.greater(num_positives_0[..., 0], 1e-3).astype("float32")
    weights_0 = weights_0 / obj_area_0
    weights_1 = np.greater(num_positives_1[..., 0], 1e-3).astype("float32")
    weights_1 = weights_1 / obj_area_1

    logits_abaa = np.concatenate([logits_ab, logits_aa], axis=2)
    logits_babb = np.concatenate([logits_ba, logits_bb], axis=2)
    logits_abaa = np.reshape(logits_abaa, [batch_size, num_rois, -1])
    logits_babb = np.reshape(logits_babb, [batch_size, num_rois, -1])

    loss_a = manual_cross_entropy(labels_0, logits_abaa, weights_0)
    loss_b = manual_cross_entropy(labels_1, logits_babb, weights_1)
    return loss_a + loss_b
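
For a quick shape check, here is a minimal sketch that exercises the loss on random inputs (all shapes and values are illustrative only, not part of this commit):

import numpy as np

B, R, D = 4, 16, 256  # batch size, sampled masks per view, feature dim
rng = np.random.default_rng(0)

pred1, pred2 = rng.normal(size=(B, R, D)), rng.normal(size=(B, R, D))
targ1, targ2 = rng.normal(size=(B, R, D)), rng.normal(size=(B, R, D))
# Mask indices as produced by DetConB.sample_masks.
pind1 = rng.integers(0, 256, size=(B, R))
pind2 = rng.integers(0, 256, size=(B, R))

loss = byol_nce_detcon(
    pred1, pred2, targ1, targ2,
    pind1, pind2, pind1, pind2,  # reuse the indices for the projections here
    temperature=0.1,
)
print(float(loss))  # a single scalar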