зеркало из https://github.com/microsoft/esvit.git
47 строк
1.3 KiB
Python
47 строк
1.3 KiB
Python
|
|
from torch import nn
|
|
|
|
# from .batch_norm import FrozenBatchNorm2d
|
|
|
|
|
|
class CNNBlockBase(nn.Module):
    """
    A CNN block is assumed to have input channels, output channels and a stride.
    The input and output of `forward()` method must be NCHW tensors.
    The method can perform arbitrary computation but must match the given
    channels and stride specification.

    Attributes:
        in_channels (int): number of channels the block expects in its input.
        out_channels (int): number of channels the block produces in its output.
        stride (int): stride the block's output must respect relative to its
            input, per the contract above.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int) -> None:
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            stride (int): stride of the block (see class docstring).
        """
        super().__init__()
        # This base class only records the block's specification; it registers
        # no submodules or parameters of its own.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    # NOTE(review): `freeze` is disabled together with the commented-out
    # `FrozenBatchNorm2d` import at the top of the file (it calls
    # `FrozenBatchNorm2d.convert_frozen_batchnorm`); re-enable both together.
    # def freeze(self):
    #     """
    #     Make this block not trainable.
    #     This method sets all parameters to `requires_grad=False`,
    #     and converts all BatchNorm layers to FrozenBatchNorm
    #
    #     Returns:
    #         the block itself
    #     """
    #     for p in self.parameters():
    #         p.requires_grad = False
    #     FrozenBatchNorm2d.convert_frozen_batchnorm(self)
    #     return self
|