fixed film bug
This commit is contained in:
Родитель
33cb2fdeee
Коммит
83098c9bcd
|
@ -33,7 +33,6 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from timm.models.efficientnet import EfficientNet
|
||||
from timm.models.layers.norm_act import BatchNormAct2d
|
||||
from timm.models.efficientnet_blocks import ConvBnAct, InvertedResidual, CondConvResidual, EdgeResidual
|
||||
|
||||
def tag_film_layers(feature_extractor_name, feature_extractor):
|
||||
|
@ -49,7 +48,8 @@ def tag_film_layers(feature_extractor_name, feature_extractor):
|
|||
modules_to_tag = []
|
||||
for child_module_name in dir(module):
|
||||
child_module = getattr(module, child_module_name)
|
||||
if child_module_name in modules_to_tag and isinstance(child_module, BatchNormAct2d):
|
||||
child_module_type = type(child_module)
|
||||
if child_module_name in modules_to_tag and issubclass(child_module_type, nn.BatchNorm2d):
|
||||
child_module.film = True
|
||||
for name, child in module.named_children():
|
||||
recursive_tag(child, name)
|
||||
|
@ -58,7 +58,8 @@ def tag_film_layers(feature_extractor_name, feature_extractor):
|
|||
def recursive_tag(module, name):
|
||||
for child_module_name in dir(module):
|
||||
child_module = getattr(module, child_module_name)
|
||||
if child_module_name in ['norm', 'norm1', 'norm2']:
|
||||
child_module_type = type(child_module)
|
||||
if child_module_name in ['norm', 'norm1', 'norm2'] and (issubclass(child_module_type, nn.LayerNorm) or issubclass(child_module_type, nn.GroupNorm)):
|
||||
child_module.film = True
|
||||
for name, child in module.named_children():
|
||||
recursive_tag(child, name)
|
||||
|
|
Загрузка…
Ссылка в новой задаче