This commit is contained in:
Daniela Massiceti 2022-12-16 06:32:37 +00:00
Родитель 33cb2fdeee
Коммит 83098c9bcd
1 изменённых файлов: 4 добавлений и 3 удалений

Просмотреть файл

@ -33,7 +33,6 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from timm.models.efficientnet import EfficientNet
from timm.models.layers.norm_act import BatchNormAct2d
from timm.models.efficientnet_blocks import ConvBnAct, InvertedResidual, CondConvResidual, EdgeResidual
def tag_film_layers(feature_extractor_name, feature_extractor):
@ -49,7 +48,8 @@ def tag_film_layers(feature_extractor_name, feature_extractor):
modules_to_tag = []
for child_module_name in dir(module):
child_module = getattr(module, child_module_name)
if child_module_name in modules_to_tag and isinstance(child_module, BatchNormAct2d):
child_module_type = type(child_module)
if child_module_name in modules_to_tag and issubclass(child_module_type, nn.BatchNorm2d):
child_module.film = True
for name, child in module.named_children():
recursive_tag(child, name)
@ -58,7 +58,8 @@ def tag_film_layers(feature_extractor_name, feature_extractor):
def recursive_tag(module, name):
for child_module_name in dir(module):
child_module = getattr(module, child_module_name)
if child_module_name in ['norm', 'norm1', 'norm2']:
child_module_type = type(child_module)
if child_module_name in ['norm', 'norm1', 'norm2'] and (issubclass(child_module_type, nn.LayerNorm) or issubclass(child_module_type, nn.GroupNorm)):
child_module.film = True
for name, child in module.named_children():
recursive_tag(child, name)