reverted back to old, learnable set encoder, rather than frozen, pre-trained one
This commit is contained in:
Родитель
861e25ed68
Коммит
78ce4534f9
|
@ -30,33 +30,22 @@ SOFTWARE.
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model.feature_extractors import create_feature_extractor
|
||||
|
||||
class SetEncoder(nn.Module):
|
||||
"""
|
||||
Simple set encoder implementing DeepSets (https://arxiv.org/abs/1703.06114). Used for modeling permutation-invariant representations on sets (mainly for extracting task-level embedding of context sets).
|
||||
"""
|
||||
def __init__(self, encoder_name='efficientnet_b0', task_embedding_dim=1280):
|
||||
def __init__(self):
|
||||
"""
|
||||
Creates an instance of SetEncoder.
|
||||
:return: Nothing.
|
||||
"""
|
||||
super(SetEncoder, self).__init__()
|
||||
self.task_embedding_dim = task_embedding_dim
|
||||
self.encoder_name = encoder_name
|
||||
|
||||
self.encoder, _ = create_feature_extractor(
|
||||
feature_extractor_name=self.encoder_name,
|
||||
pretrained=True,
|
||||
with_film=False,
|
||||
learn_extractor=False
|
||||
)
|
||||
|
||||
self.encoder = SimplePrePoolNet()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Function that encodes a set of N elements into N embeddings, each of dim self.task_embedding_dim
|
||||
Function that encodes a set of N elements into N embeddings, each of dim 64
|
||||
:param x: (torch.Tensor) Set of elements (for clips it has the shape: batch x clip length x C x H x W).
|
||||
:return: (torch.Tensor) Individual element embeddings.
|
||||
"""
|
||||
|
@ -87,8 +76,49 @@ class SetEncoder(nn.Module):
|
|||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self.task_embedding_dim
|
||||
return 64
|
||||
|
||||
class SimplePrePoolNet(nn.Module):
|
||||
"""
|
||||
Simple network to encode elements of a set into low-dimensional embeddings. Used before pooling them to obtain a task-level embedding. A multi-layer convolutional network is used, similar to that in https://github.com/cambridge-mlg/cnaps.
|
||||
"""
|
||||
def __init__(self):
|
||||
"""
|
||||
Creates an instance of SimplePrePoolNet.
|
||||
:return: Nothing.
|
||||
"""
|
||||
super(SimplePrePoolNet, self).__init__()
|
||||
self.layer1 = self._make_conv2d_layer(3, 64)
|
||||
self.layer2 = self._make_conv2d_layer(64, 64)
|
||||
self.layer3 = self._make_conv2d_layer(64, 64)
|
||||
self.layer4 = self._make_conv2d_layer(64, 64)
|
||||
self.layer5 = self._make_conv2d_layer(64, 64)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
|
||||
|
||||
@staticmethod
|
||||
def _make_conv2d_layer(in_maps, out_maps, kernel_size=3):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_maps, out_maps, kernel_size=3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(out_maps),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Function that encodes each element in x into a 64-dimensional embedding.
|
||||
:param x: (torch.Tensor) Set of elements (for clips it has the shape: batch*clip length x C x H x W).
|
||||
:return: (torch.Tensor) Each element in x encoded as a 64-dimensional vector.
|
||||
"""
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.layer5(x)
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
return x
|
||||
|
||||
class NullSetEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Загрузка…
Ссылка в новой задаче