reverted back to old, learnable set encoder, rather than frozen, pre-trained one

This commit is contained in:
Daniela Massiceti 2023-03-27 04:36:21 +01:00
Родитель 861e25ed68
Коммит 78ce4534f9
1 изменённых файлов: 46 добавлений и 16 удалений

Просмотреть файл

@ -30,33 +30,22 @@ SOFTWARE.
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.feature_extractors import create_feature_extractor
class SetEncoder(nn.Module):
"""
Simple set encoder implementing DeepSets (https://arxiv.org/abs/1703.06114). Used for modeling permutation-invariant representations on sets (mainly for extracting task-level embedding of context sets).
"""
def __init__(self, encoder_name='efficientnet_b0', task_embedding_dim=1280):
def __init__(self):
"""
Creates an instance of SetEncoder.
:return: Nothing.
"""
super(SetEncoder, self).__init__()
self.task_embedding_dim = task_embedding_dim
self.encoder_name = encoder_name
self.encoder, _ = create_feature_extractor(
feature_extractor_name=self.encoder_name,
pretrained=True,
with_film=False,
learn_extractor=False
)
self.encoder = SimplePrePoolNet()
def forward(self, x):
"""
Function that encodes a set of N elements into N embeddings, each of dim self.task_embedding_dim
Function that encodes a set of N elements into N embeddings, each of dim 64
:param x: (torch.Tensor) Set of elements (for clips it has the shape: batch x clip length x C x H x W).
:return: (torch.Tensor) Individual element embeddings.
"""
@ -87,8 +76,49 @@ class SetEncoder(nn.Module):
@property
def output_size(self):
return self.task_embedding_dim
return 64
class SimplePrePoolNet(nn.Module):
"""
Simple network to encode elements of a set into low-dimensional embeddings. Used before pooling them to obtain a task-level embedding. A multi-layer convolutional network is used, similar to that in https://github.com/cambridge-mlg/cnaps.
"""
def __init__(self):
"""
Creates an instance of SimplePrePoolNet.
:return: Nothing.
"""
super(SimplePrePoolNet, self).__init__()
self.layer1 = self._make_conv2d_layer(3, 64)
self.layer2 = self._make_conv2d_layer(64, 64)
self.layer3 = self._make_conv2d_layer(64, 64)
self.layer4 = self._make_conv2d_layer(64, 64)
self.layer5 = self._make_conv2d_layer(64, 64)
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
@staticmethod
def _make_conv2d_layer(in_maps, out_maps, kernel_size=3):
return nn.Sequential(
nn.Conv2d(in_maps, out_maps, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_maps),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=False)
)
def forward(self, x):
"""
Function that encodes each element in x into a 64-dimensional embedding.
:param x: (torch.Tensor) Set of elements (for clips it has the shape: batch*clip length x C x H x W).
:return: (torch.Tensor) Each element in x encoded as a 64-dimensional vector.
"""
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.layer5(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
return x
class NullSetEncoder(nn.Module):
def __init__(self):
super().__init__()