한 클래스 내에서 다른 클래스의 코드를 사용해야 할 때, 중복으로 코딩을 할 필요 없이 해당 클래스를 상속받을 수 있는 방법으로 클래스 상속이라는 방법이 있습니다.
예시를 들어보겠습니다. 파이토치를 사용하고 있는 중에 Convolution Layer와 BatchNorm Layer, Activation Function을 연결한 하나의 블록을 만들고자 합니다. 다만 저는 어떤 곳에서는 Conv1D, 다른 곳에서는 Conv2D 또는 Conv3D를 사용해야만 합니다. 이를 각각 구현했다면 3개의 블록을 구현한 각각의 클래스에 중복되는 코드들이 존재하게 될 것입니다.
이를 상속을 통해 빠르게 구현해보도록 하겠습니다.
실습에서 보여드릴 코드는 파이토치 공식 구현 코드인데 사용자 정의 함수가 하나 있어서 미리 선언해두도록 하겠습니다.
def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
"""
Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
Otherwise we will make a tuple of length n, all with value of x.
reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
Args:
x (Any): input value
n (int): length of the resulting tuple
"""
if isinstance(x, collections.abc.Iterable):
return tuple(x)
return tuple(repeat(x, n))
상속할 클래스 구현
우선 컨볼루전 레이어 + 배치노말라이제이션 레이어 + 액티베이션 레이어를 합친 블록 클래스를 생성합니다. 이 클래스를 상속받아 원하는 블록을 만들 계획입니다.
해당 코드는 파이토치 공식 레포의 코드입니다.
'''
파이토치 공식 구현 코드입니다.
https://github.com/pytorch/vision/blob/149edda463b54b3eabe989e260a839727c89d099/torchvision/ops/misc.py#L125
'''
import warnings
import collections
from typing import Callable, List, Optional, Sequence, Tuple, Union, Any
from itertools import repeat
import torch
from torch import nn, Tensor
class ConvNormActivation(torch.nn.Sequential):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, ...]] = 3,
stride: Union[int, Tuple[int, ...]] = 1,
padding: Optional[Union[int, Tuple[int, ...], str]] = None,
groups: int = 1,
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
dilation: Union[int, Tuple[int, ...]] = 1,
inplace: Optional[bool] = True,
bias: Optional[bool] = None,
conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
) -> None:
print("ConvNormActivation의 시작, 아마 두번째?")
if padding is None:
if isinstance(kernel_size, int) and isinstance(dilation, int):
padding = (kernel_size - 1) // 2 * dilation
else:
_conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
kernel_size = _make_ntuple(kernel_size, _conv_dim)
dilation = _make_ntuple(dilation, _conv_dim)
padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
if bias is None:
bias = norm_layer is None
layers = [
conv_layer(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=dilation,
groups=groups,
bias=bias,
)
]
if norm_layer is not None:
layers.append(norm_layer(out_channels))
if activation_layer is not None:
params = {} if inplace is None else {"inplace": inplace}
layers.append(activation_layer(**params))
super().__init__(*layers)
self.out_channels = out_channels
if self.__class__ == ConvNormActivation:
warnings.warn(
"Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
)
print("ConvNormActivation의 끝, 아마 세번째?")
저는 구현 순서가 어떻게 진행되는지 궁금하므로 두 개의 프린트문을 넣었습니다.
상속받을 클래스 구현
Conv2D를 사용해서 위의 클래스를 상속받아 Conv - BatchNorm - ReLU6 블록을 구현하도록 하겠습니다.
이 코드 역시 파이토치에서 제공하는 코드입니다.
'''
파이토치 공식 구현 코드입니다.
https://github.com/pytorch/vision/blob/149edda463b54b3eabe989e260a839727c89d099/torchvision/ops/misc.py#L125
'''
import warnings
import collections
from typing import Callable, List, Optional, Sequence, Tuple, Union, Any
from itertools import repeat
import torch
from torch import nn, Tensor
class Conv2dNormActivation(ConvNormActivation):
"""
Configurable block used for Convolution2d-Normalization-Activation blocks.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
kernel_size: (int, optional): Size of the convolving kernel. Default: 3
stride (int, optional): Stride of the convolution. Default: 1
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]] = 3,
stride: Union[int, Tuple[int, int]] = 1,
padding: Optional[Union[int, Tuple[int, int], str]] = None,
groups: int = 1,
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
dilation: Union[int, Tuple[int, int]] = 1,
inplace: Optional[bool] = True,
bias: Optional[bool] = None,
) -> None:
print("아마 첫 시작이겠죠?")
# 여기서 클래스 상속이 일어납니다.
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups,
norm_layer,
activation_layer,
dilation,
inplace,
bias,
torch.nn.Conv2d,
)
print("클래스 상속 후니 마지막이겠죠?")
이 코드를 보시면 어떠한 동작도 없이 상속만 받는 것을 확인할 수 있습니다. 과연 어떻게 위의 코드를 상속받아 블록을 구현할까요? 제 생각에는 super().__init__
이 클래스 상속을 하는 부분이니 프린트문에서 적은 대로 순서가 진행될 것 같습니다.
상속 테스트
norm_layer = nn.BatchNorm2d
Conv2dNormActivation(1, 3, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
위의 코드를 통해 Conv2d - BatchNorm2d - ReLU6 블록을 구현해보면 문제없이 블록을 구성하는 것을 확인하실 수 있습니다.
마지막으로 프린트 순서가 어떻게 진행되었는지 확인해보겠습니다.
아마 첫 시작이겠죠?
ConvNormActivation의 시작, 아마 두번째?
ConvNormActivation의 끝, 아마 세번째?
클래스 상속 후니 마지막이겠죠?
예상대로 Conv2dNormActivation에서 첫 프린트문이 나온 후 super().__init__
을 통해서 ConvNormActivation의 프린트문이 순서대로 나오고, Conv2dNormActivation의 마지막 프린트문이 나오는 걸 확인하실 수 있습니다.