Mirror of https://github.com/microsoft/torchgeo.git

Adding Detcon B loss and partial implementation

This commit is contained in:
Parent: b3b3f7777c
Commit: 085578f776
@@ -0,0 +1,14 @@
experiment:
  task: "ssl"
  name: "test_detcon"
  module:
    model: "detcon"
    encoder: "resnet18"
    input_channels: 3
    training_mode: "self-supervised"
    imagenet_pretraining: True
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
  datamodule:
    batch_size: 64
    num_workers: 6
@@ -0,0 +1,14 @@
experiment:
  task: "ssl"
  name: "test_detcon"
  module:
    model: "detcon"
    encoder: "resnet18"
    input_channels: 3
    training_mode: "self-supervised"
    imagenet_pretraining: True
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
  datamodule:
    batch_size: 64
    num_workers: 6
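
For context, a config like this might be consumed as follows (an editorial sketch, not part of the commit; it assumes OmegaConf, which torchgeo's train.py uses for its conf files, and the file path is an assumption):

from omegaconf import OmegaConf

conf = OmegaConf.load("conf/detcon.yaml")  # hypothetical path
assert conf.experiment.module.model == "detcon"
learning_rate = conf.experiment.module.learning_rate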
@@ -0,0 +1,177 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""DetCon-B: DetCon implementation for BYOL."""

from copy import deepcopy
from typing import Any, Mapping, Optional, Text, Tuple, Union, cast

import numpy as np
import skimage.segmentation
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Module

from torchgeo.trainers import byol

Module.__module__ = "torch.nn"


def featurewise_std(x: np.ndarray) -> np.ndarray:
    """Computes the featurewise standard deviation."""
    return np.mean(np.std(x, axis=0))


def compute_fh_segmentation(
    image_np: np.ndarray, scale: float, min_size: int
) -> np.ndarray:
    """Computes a Felzenszwalb-Huttenlocher (FH) segmentation of the image."""
    segmented_image = skimage.segmentation.felzenszwalb(
        image_np, scale=scale, min_size=min_size
    )
    # Segment ids are stored as uint8 to keep the masks compact (max 256 segments).
    segmented_image = segmented_image.astype(np.dtype("<u1"))
    return segmented_image
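
# Example (editorial sketch; the scale/min_size values are assumptions):
#   image = np.random.rand(224, 224, 3)
#   fh_mask = compute_fh_segmentation(image, scale=1000, min_size=10)
#   fh_mask.shape == (224, 224) and fh_mask.dtype == np.uint8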


class DetConB(Module):
    """DetCon-B's training component definition."""

    def config_task(self) -> None:
        """Configures the task based on kwargs parameters passed to the constructor."""
        assert self.hparams["training_mode"] in [
            "self-supervised",
            "supervised",
            "both",
        ]
        self.training_mode = self.hparams["training_mode"]

    def __init__(
        self,
        mode: Text,
        model: Module,
        image_size: Tuple[int, int] = (224, 224),
        hidden_layer: Union[str, int] = -2,
        input_channels: int = 3,
        projection_size: int = 256,
        hidden_size: int = 4096,
        augment_fn: Optional[Module] = None,
        beta: float = 0.99,
        **kwargs: Any,
    ) -> None:
        """Constructs the experiment.

        Args:
            mode: A string, equivalent to FLAGS.mode when running normally.
            model: The backbone model to wrap.
            kwargs: Extra hyperparameters; must include "training_mode".
        """
        super().__init__()
        self.mode = mode
        # Plain-Module stand-in for Lightning's hparams; config_task reads from it.
        self.hparams = kwargs

        self.augment: Module
        if augment_fn is None:
            self.augment = byol.SimCLRAugmentation(image_size)
        else:
            self.augment = augment_fn

        self.beta = beta
        self.input_channels = input_channels
        self.encoder = byol.EncoderWrapper(
            model, projection_size, hidden_size, layer=hidden_layer
        )
        self.predictor = byol.MLP(projection_size, projection_size, hidden_size)
        self._target: Optional[Module] = None

        # Perform a single forward pass to initialize the wrapper correctly
        self.encoder(torch.zeros(2, self.input_channels, *image_size))

        self.config_task()

    def create_binary_mask(
        self,
        batch_size: int,
        num_pixels: int,
        masks: np.ndarray,
        max_mask_id: int = 256,
        downsample: Tuple[int, int, int, int] = (1, 32, 32, 1),
    ) -> np.ndarray:
        """Generates binary masks from the Felzenszwalb masks.

        From an FH mask of shape [batch_size, H, W] (values in range
        [0, max_mask_id]), produces corresponding (downsampled) binary masks of
        shape [batch_size, max_mask_id, H*W/downsample].

        Args:
            batch_size: batch size of the masks
            num_pixels: number of points on the spatial grid after downsampling
            masks: Felzenszwalb masks
            max_mask_id: number of unique masks in the Felzenszwalb segmentation
            downsample: rate at which masks must be downsampled, as an
                NHWC-style (1, stride, stride, 1) tuple

        Returns:
            binary_mask: binary mask with the specification above
        """
        fh_mask_to_use = self.hparams["fh_mask_to_use"]
        mask = masks[..., fh_mask_to_use : (fh_mask_to_use + 1)]

        # One-hot encode the segment ids: [B, H, W, max_mask_id].
        mask_ids = np.arange(max_mask_id).reshape(1, 1, 1, max_mask_id)
        binary_mask = np.equal(mask_ids, mask).astype("float32")

        # F.avg_pool2d expects an NCHW tensor, so convert and permute first.
        binary_mask_t = torch.from_numpy(binary_mask).permute(0, 3, 1, 2)
        stride = downsample[1]
        binary_mask_t = F.avg_pool2d(
            binary_mask_t, stride, stride, count_include_pad=False
        )
        binary_mask = binary_mask_t.permute(0, 2, 3, 1).numpy()

        binary_mask = binary_mask.reshape(batch_size, num_pixels, max_mask_id)
        # Assign each downsampled cell to its dominant segment id...
        binary_mask = np.argmax(binary_mask, axis=-1)
        # ...and re-binarize, ending with shape [B, max_mask_id, num_pixels].
        binary_mask = np.eye(max_mask_id)[binary_mask]
        binary_mask = np.transpose(binary_mask, [0, 2, 1])
        return binary_mask
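
    # Shape sketch (editorial, values assumed): with masks of shape
    # [2, 224, 224, 1] and downsample (1, 32, 32, 1), num_pixels is
    # (224 // 32) ** 2 == 49 and the result has shape [2, 256, 49].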

    def sample_masks(
        self, binary_mask: np.ndarray, batch_size: int, n_random_vectors: int = 16
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Samples which binary masks to use in the loss."""
        # Only sample masks that actually cover at least one downsampled cell.
        mask_exists = np.greater(binary_mask.sum(-1), 1e-3)
        sel_masks = mask_exists.astype("float32") + 1e-11
        sel_masks = sel_masks / sel_masks.sum(1, keepdims=True)

        # np.random.choice takes a 1-D probability vector (not log-probabilities),
        # so sample each batch element separately.
        mask_ids = np.stack(
            [
                np.random.choice(
                    sel_masks.shape[1], size=n_random_vectors, p=sel_masks[b]
                )
                for b in range(batch_size)
            ]
        )

        smpl_masks = np.stack(
            [binary_mask[b][mask_ids[b]] for b in range(batch_size)]
        )
        return smpl_masks, mask_ids
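
    # E.g. (editorial sketch): binary_mask of shape [2, 256, 49] with
    # batch_size=2 and n_random_vectors=16 yields smpl_masks [2, 16, 49]
    # and mask_ids [2, 16].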

    @property
    def target(self) -> Module:
        """The "target" model."""
        if self._target is None:
            self._target = deepcopy(self.encoder)
        return self._target

    def update_target(self) -> None:
        """Method to update the "target" model weights."""
        for p, pt in zip(self.encoder.parameters(), self.target.parameters()):
            pt.data = self.beta * pt.data + (1 - self.beta) * p.data
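
    # Note (editorial): with beta=0.99 each update_target() call moves the
    # target weights 1% of the way toward the online encoder, the standard
    # BYOL exponential moving average.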

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the encoder model through the MLP and prediction head.

        Args:
            x: tensor of data to run through the model

        Returns:
            output from the model
        """
        return cast(Tensor, self.predictor(self.encoder(x)))
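
    # E.g. (editorial sketch): for a module built with the defaults above,
    # module(torch.randn(2, 3, 224, 224)) returns predictions of shape
    # [2, 256] (projection_size).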


# Stub from this commit's partial implementation; body not yet written.
def run_detcon_b_forward_on_view(
    view_encoder: Any,
    projector: Any,
    predictor: Any,
    classifier: Any,
    is_training: bool,
    images: np.ndarray,
    masks: np.ndarray,
    suffix: Text = "",
):
    """Runs DetCon-B's forward pass on a single augmented view (not yet implemented)."""
    pass


# Stub from this commit's partial implementation; not yet wired into DetConB.
def _forward(
    self,
    inputs: Mapping[Text, np.ndarray],
    is_training: bool,
) -> Mapping[Text, np.ndarray]:
    """Forward application of byol's architecture.

    Args:
        inputs: A batch of data, i.e. a dictionary, with either two keys
            (`images` and `labels`) or three keys (`view1`, `view2`, `labels`).
        is_training: Training or evaluating the model? When True, inputs must
            contain keys `view1` and `view2`. When False, inputs must contain
            key `images`.

    Returns:
        All outputs of the model, i.e. a dictionary with projection, prediction
        and logits keys, for either the two views, or the image.
    """
    pass
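
For reference, constructing the module as sketched so far might look like this (an editorial sketch, not part of the commit; the torchvision backbone and the "train" mode string are assumptions matching the config above):

import torchvision.models

backbone = torchvision.models.resnet18(pretrained=True)
module = DetConB(
    mode="train",
    model=backbone,
    image_size=(224, 224),
    input_channels=3,
    training_mode="self-supervised",  # stored in hparams and checked by config_task
)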
@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Loss functions for DetCon-B."""

from typing import Optional

import numpy as np


def manual_cross_entropy(
    labels: np.ndarray, logits: np.ndarray, weight: np.ndarray
) -> np.ndarray:
    """Manually computes cross-entropy loss.

    Args:
        labels: tensor of target label distributions
        logits: tensor of unnormalized logits
        weight: per-sample weights

    Returns:
        weighted mean cross-entropy loss
    """
    # Numerically stable log-softmax in numpy (shift by the row max first).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    ce = -weight * np.sum(labels * log_softmax, axis=-1)
    return np.mean(ce)
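
# Worked example (editorial): labels [[1, 0]], logits [[0, 0]], weight 1.0
# gives -log(0.5) ~= 0.693, the cross entropy of a uniform prediction over
# two classes.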


def l2_normalize(
    x: np.ndarray,
    axis: Optional[int] = None,
    epsilon: float = 1e-12,
) -> np.ndarray:
    """L2-normalizes a tensor along an axis with numerical stability."""
    square_sum = np.sum(np.square(x), axis=axis, keepdims=True)
    x_inv_norm = 1 / np.sqrt(np.maximum(square_sum, epsilon))
    return x * x_inv_norm
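
# E.g. (editorial): l2_normalize(np.array([3.0, 4.0]), axis=-1) -> [0.6, 0.8].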


def byol_nce_detcon(pred1, pred2, target1, target2,
                    pind1, pind2, tind1, tind2,
                    temperature=0.1, use_replicator_loss=True,
                    local_negatives=True):
    """Compute the NCE scores from pairs of predictions and targets.

    This implements the batched form of the loss described in
    Section 3.1, Equation 3 in https://arxiv.org/pdf/2103.10957.pdf.

    Args:
        pred1 (np.array): the prediction from first view.
        pred2 (np.array): the prediction from second view.
        target1 (np.array): the projection from first view.
        target2 (np.array): the projection from second view.
        pind1 (np.array): mask indices for first view's prediction.
        pind2 (np.array): mask indices for second view's prediction.
        tind1 (np.array): mask indices for first view's projection.
        tind2 (np.array): mask indices for second view's projection.
        temperature (float): the temperature to use for the NCE loss.
        use_replicator_loss (bool): use cross-replica samples.
        local_negatives (bool): whether to include local negatives.

    Returns:
        A single scalar loss for the XT-NCE objective.
    """
    batch_size = pred1.shape[0]
    num_rois = pred1.shape[1]
    feature_dim = pred1.shape[-1]
    infinity_proxy = 1e9  # Used for masks to proxy a very large number.

    def make_same_obj(ind_0, ind_1):
        # [B, num_rois, num_rois] indicator of "same mask id" pairs, expanded
        # to [B, num_rois, 1, num_rois] to match the logits layout below.
        same_obj = np.equal(ind_0.reshape([batch_size, num_rois, 1]),
                            ind_1.reshape([batch_size, 1, num_rois]))
        return np.expand_dims(same_obj.astype("float32"), axis=2)

    same_obj_aa = make_same_obj(pind1, tind1)
    same_obj_ab = make_same_obj(pind1, tind2)
    same_obj_ba = make_same_obj(pind2, tind1)
    same_obj_bb = make_same_obj(pind2, tind2)

    # L2 normalize the tensors to use for the cosine-similarity
    pred1 = l2_normalize(pred1, axis=-1)
    pred2 = l2_normalize(pred2, axis=-1)
    target1 = l2_normalize(target1, axis=-1)
    target2 = l2_normalize(target2, axis=-1)

    # Only supports a single GPU for now: the "large" (cross-replica) targets
    # are just the local ones and use_replicator_loss is effectively ignored.
    target1_large = target1
    target2_large = target2
    labels_local = np.eye(batch_size, dtype="float32")
    labels_ext = np.eye(batch_size, batch_size * 2, dtype="float32")

    labels_local = np.expand_dims(np.expand_dims(labels_local, axis=2), axis=1)
    labels_ext = np.expand_dims(np.expand_dims(labels_ext, axis=2), axis=1)

    # Do our matmuls and mask out appropriately.
    logits_aa = np.einsum("abk,uvk->abuv", pred1, target1_large) / temperature
    logits_bb = np.einsum("abk,uvk->abuv", pred2, target2_large) / temperature
    logits_ab = np.einsum("abk,uvk->abuv", pred1, target2_large) / temperature
    logits_ba = np.einsum("abk,uvk->abuv", pred2, target1_large) / temperature

    labels_aa = labels_local * same_obj_aa
    labels_ab = labels_local * same_obj_ab
    labels_ba = labels_local * same_obj_ba
    labels_bb = labels_local * same_obj_bb

    # Within-view pairs of the same object are never positives: mask them out.
    logits_aa = logits_aa - infinity_proxy * labels_local * same_obj_aa
    logits_bb = logits_bb - infinity_proxy * labels_local * same_obj_bb
    labels_aa = 0.0 * labels_aa
    labels_bb = 0.0 * labels_bb

    if not local_negatives:
        logits_aa = logits_aa - infinity_proxy * labels_local * (1 - same_obj_aa)
        logits_ab = logits_ab - infinity_proxy * labels_local * (1 - same_obj_ab)
        logits_ba = logits_ba - infinity_proxy * labels_local * (1 - same_obj_ba)
        logits_bb = logits_bb - infinity_proxy * labels_local * (1 - same_obj_bb)

    labels_abaa = np.concatenate([labels_ab, labels_aa], axis=2)
    labels_babb = np.concatenate([labels_ba, labels_bb], axis=2)

    labels_0 = np.reshape(labels_abaa, [batch_size, num_rois, -1])
    labels_1 = np.reshape(labels_babb, [batch_size, num_rois, -1])

    num_positives_0 = np.sum(labels_0, axis=-1, keepdims=True)
    num_positives_1 = np.sum(labels_1, axis=-1, keepdims=True)

    labels_0 = labels_0 / np.maximum(num_positives_0, 1)
    labels_1 = labels_1 / np.maximum(num_positives_1, 1)

    obj_area_0 = np.sum(make_same_obj(pind1, pind1), axis=(2, 3))
    obj_area_1 = np.sum(make_same_obj(pind2, pind2), axis=(2, 3))

    # Weight each ROI by 1/area so that large segments do not dominate.
    weights_0 = np.greater(num_positives_0[..., 0], 1e-3).astype("float32")
    weights_0 = weights_0 / obj_area_0
    weights_1 = np.greater(num_positives_1[..., 0], 1e-3).astype("float32")
    weights_1 = weights_1 / obj_area_1

    logits_abaa = np.concatenate([logits_ab, logits_aa], axis=2)
    logits_babb = np.concatenate([logits_ba, logits_bb], axis=2)

    logits_abaa = np.reshape(logits_abaa, [batch_size, num_rois, -1])
    logits_babb = np.reshape(logits_babb, [batch_size, num_rois, -1])

    loss_a = manual_cross_entropy(labels_0, logits_abaa, weights_0)
    loss_b = manual_cross_entropy(labels_1, logits_babb, weights_1)
    loss = loss_a + loss_b

    return loss
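
A quick shape check for the loss (an editorial sketch, not part of the commit; all sizes are assumptions):

import numpy as np

B, R, D = 2, 16, 256  # batch size, sampled masks per image, embedding dim
rng = np.random.default_rng(0)
pred1, pred2 = rng.normal(size=(B, R, D)), rng.normal(size=(B, R, D))
target1, target2 = rng.normal(size=(B, R, D)), rng.normal(size=(B, R, D))
# Mask indices: which FH segment each of the R sampled vectors came from.
pind1, pind2, tind1, tind2 = (rng.integers(0, 256, size=(B, R)) for _ in range(4))
loss = byol_nce_detcon(pred1, pred2, target1, target2, pind1, pind2, tind1, tind2)
print(float(loss))  # a single scalar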