diff --git a/model/set_encoders.py b/model/set_encoders.py
index d9ad189..f92afc5 100644
--- a/model/set_encoders.py
+++ b/model/set_encoders.py
@@ -30,33 +30,22 @@ SOFTWARE.
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-
-from model.feature_extractors import create_feature_extractor
 
 class SetEncoder(nn.Module):
     """
     Simple set encoder implementing DeepSets (https://arxiv.org/abs/1703.06114). Used for modeling permutation-invariant representations on sets (mainly for extracting task-level embedding of context sets).
     """
-    def __init__(self, encoder_name='efficientnet_b0', task_embedding_dim=1280):
+    def __init__(self):
         """
         Creates an instance of SetEncoder.
         :return: Nothing.
         """
         super(SetEncoder, self).__init__()
-        self.task_embedding_dim = task_embedding_dim
-        self.encoder_name = encoder_name
-
-        self.encoder, _ = create_feature_extractor(
-            feature_extractor_name=self.encoder_name,
-            pretrained=True,
-            with_film=False,
-            learn_extractor=False
-        )
-
+        self.encoder = SimplePrePoolNet()
+
     def forward(self, x):
         """
-        Function that encodes a set of N elements into N embeddings, each of dim self.task_embedding_dim
+        Function that encodes a set of N elements into N embeddings, each of dim 64
         :param x: (torch.Tensor) Set of elements (for clips it has the shape: batch x clip length x C x H x W).
         :return: (torch.Tensor) Individual element embeddings.
         """
@@ -87,8 +76,49 @@
 
     @property
     def output_size(self):
-        return self.task_embedding_dim
+        return 64
 
+class SimplePrePoolNet(nn.Module):
+    """
+    Simple network to encode elements of a set into low-dimensional embeddings. Used before pooling them to obtain a task-level embedding. A multi-layer convolutional network is used, similar to that in https://github.com/cambridge-mlg/cnaps.
+    """
+    def __init__(self):
+        """
+        Creates an instance of SimplePrePoolNet.
+        :return: Nothing.
+        """
+        super(SimplePrePoolNet, self).__init__()
+        self.layer1 = self._make_conv2d_layer(3, 64)
+        self.layer2 = self._make_conv2d_layer(64, 64)
+        self.layer3 = self._make_conv2d_layer(64, 64)
+        self.layer4 = self._make_conv2d_layer(64, 64)
+        self.layer5 = self._make_conv2d_layer(64, 64)
+        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
+
+    @staticmethod
+    def _make_conv2d_layer(in_maps, out_maps, kernel_size=3):
+        return nn.Sequential(
+            nn.Conv2d(in_maps, out_maps, kernel_size=kernel_size, stride=1, padding=1),
+            nn.BatchNorm2d(out_maps),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=False)
+        )
+
+    def forward(self, x):
+        """
+        Function that encodes each element in x into a 64-dimensional embedding.
+        :param x: (torch.Tensor) Set of elements (for clips it has the shape: batch*clip length x C x H x W).
+        :return: (torch.Tensor) Each element in x encoded as a 64-dimensional vector.
+        """
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.layer5(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        return x
+
 class NullSetEncoder(nn.Module):
     def __init__(self):
         super().__init__()
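
Smoke test (not part of this diff): a minimal sketch of how the reworked encoder could be exercised end to end. The clip shape, frame size, and mean-pooling step below are assumptions for illustration of the DeepSets pattern named in the docstrings; since the body of SetEncoder.forward falls outside these hunks, the inner SimplePrePoolNet is called directly.

import torch

from model.set_encoders import SetEncoder

# Hypothetical usage: encode a small batch of clips into per-frame 64-dim
# embeddings, then mean-pool across the whole set to get one
# permutation-invariant task embedding (DeepSets-style pooling is assumed).
encoder = SetEncoder()
encoder.eval()  # BatchNorm layers use running stats during evaluation

batch, clip_len, c, h, w = 2, 8, 3, 84, 84  # shapes chosen for illustration
clips = torch.randn(batch, clip_len, c, h, w)

# SimplePrePoolNet consumes a flat batch of frames: (batch*clip_len) x C x H x W.
frames = clips.view(-1, c, h, w)
with torch.no_grad():
    per_frame = encoder.encoder(frames)  # (batch*clip_len) x 64
assert per_frame.shape == (batch * clip_len, encoder.output_size)

# Pooling over the set dimension makes the result order-invariant.
task_embedding = per_frame.mean(dim=0)
print(task_embedding.shape)  # torch.Size([64])

The five conv blocks each halve the spatial resolution, and the trailing AdaptiveAvgPool2d((1,1)) collapses whatever remains, so the encoder yields a 64-dim vector per frame regardless of input resolution.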