기타

클래스 상속, super() __init()

2022. 11. 7. 23:07
목차
  1.  
  2. 상속할 클래스 구현
  3.  
  4. 상속받을 클래스 구현
  5.  
  6. 상속 테스트
반응형

한 클래스 내에서 다른 클래스의 코드를 사용해야 할 때, 중복으로 코딩을 할 필요 없이 해당 클래스를 상속받을 수 있는 방법으로 클래스 상속이라는 방법이 있습니다.

예시를 들어보겠습니다. 파이토치를 사용하고 있는 중에 Convolution Layer와 BatchNorm Layer, Activation Function을 연결한 하나의 블록을 만들고자 합니다. 다만 저는 어떤 곳에서는 Conv1D, 다른 곳에서는 Conv2D 또는 Conv3D를 사용해야만 합니다. 이를 각각 구현했다면 3개의 블록을 구현한 각각의 클래스에 중복되는 코드들이 존재하게 될 것입니다.

이를 상속을 통해 빠르게 구현해보도록 하겠습니다.

실습에서 보여드릴 코드는 파이토치 공식 구현 코드인데 사용자 정의 함수가 하나 있어서 미리 선언해두도록 하겠습니다.

def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
    Otherwise we will make a tuple of length n, all with value of x.
    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
    Args:
        x (Any): input value
        n (int): length of the resulting tuple
    """
    if isinstance(x, collections.abc.Iterable):
        return tuple(x)
    return tuple(repeat(x, n))

 

상속할 클래스 구현

우선 컨볼루전 레이어 + 배치노말라이제이션 레이어 + 액티베이션 레이어를 합친 블록 클래스를 생성합니다. 이 클래스를 상속받아 원하는 블록을 만들 계획입니다.

해당 코드는 파이토치 공식 레포의 코드입니다.

'''
파이토치 공식 구현 코드입니다.
https://github.com/pytorch/vision/blob/149edda463b54b3eabe989e260a839727c89d099/torchvision/ops/misc.py#L125
'''

import warnings
import collections
from typing import Callable, List, Optional, Sequence, Tuple, Union, Any
from itertools import repeat

import torch
from torch import nn, Tensor

class ConvNormActivation(torch.nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, ...]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
        conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
    ) -> None:

        print("ConvNormActivation의 시작, 아마 두번째?")

        if padding is None:
            if isinstance(kernel_size, int) and isinstance(dilation, int):
                padding = (kernel_size - 1) // 2 * dilation
            else:
                _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
                kernel_size = _make_ntuple(kernel_size, _conv_dim)
                dilation = _make_ntuple(dilation, _conv_dim)
                padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
        if bias is None:
            bias = norm_layer is None

        layers = [
            conv_layer(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        ]

        if norm_layer is not None:
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            params = {} if inplace is None else {"inplace": inplace}
            layers.append(activation_layer(**params))
        super().__init__(*layers)
        self.out_channels = out_channels

        if self.__class__ == ConvNormActivation:
            warnings.warn(
                "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
            )

        print("ConvNormActivation의 끝, 아마 세번째?")

저는 구현 순서가 어떻게 진행되는지 궁금하므로 두 개의 프린트문을 넣었습니다.

 

상속받을 클래스 구현

Conv2D를 사용해서 위의 클래스를 상속받아 Conv - BatchNorm - ReLU6 블록을 구현하도록 하겠습니다.

이 코드 역시 파이토치에서 제공하는 코드입니다.

'''
파이토치 공식 구현 코드입니다.
https://github.com/pytorch/vision/blob/149edda463b54b3eabe989e260a839727c89d099/torchvision/ops/misc.py#L125
'''

import warnings
import collections
from typing import Callable, List, Optional, Sequence, Tuple, Union, Any
from itertools import repeat

import torch
from torch import nn, Tensor

class Conv2dNormActivation(ConvNormActivation):
    """
    Configurable block used for Convolution2d-Normalization-Activation blocks.
    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size: (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Optional[Union[int, Tuple[int, int], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, int]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:

        print("아마 첫 시작이겠죠?")

        # 여기서 클래스 상속이 일어납니다.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            norm_layer,
            activation_layer,
            dilation,
            inplace,
            bias,
            torch.nn.Conv2d,
        )

        print("클래스 상속 후니 마지막이겠죠?")

이 코드를 보시면 어떠한 동작도 없이 상속만 받는 것을 확인할 수 있습니다. 과연 어떻게 위의 코드를 상속받아 블록을 구현할까요? 제 생각에는 super().__init__이 클래스 상속을 하는 부분이니 프린트문에서 적은 대로 순서가 진행될 것 같습니다.

 

상속 테스트

norm_layer = nn.BatchNorm2d
Conv2dNormActivation(1, 3, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)

위의 코드를 통해 Conv2d - BatchNorm2d - ReLU6 블록을 구현해보면 문제없이 블록을 구성하는 것을 확인하실 수 있습니다.

마지막으로 프린트 순서가 어떻게 진행되었는지 확인해보겠습니다.

아마 첫 시작이겠죠?
ConvNormActivation의 시작, 아마 두번째?
ConvNormActivation의 끝, 아마 세번째?
클래스 상속 후니 마지막이겠죠?

예상대로 Conv2dNormActivation에서 첫 프린트문이 나온 후 super().__init__ 을 통해서 ConvNormActivation의 프린트문이 순서대로 나오고, Conv2dNormActivation의 마지막 프린트문이 나오는 걸 확인하실 수 있습니다.

반응형
저작자표시
  1.  
  2. 상속할 클래스 구현
  3.  
  4. 상속받을 클래스 구현
  5.  
  6. 상속 테스트
'기타' 카테고리의 다른 글
  • Tensorflow, Keras to TFJS 컨버팅
새우까앙
새우까앙
뉴비 분석가
새우까앙
새우위키
새우까앙
전체
오늘
어제
  • 전체보기 (64)
    • 이론 (42)
      • LLM (6)
      • Diffusion (3)
      • ML 기초 (10)
      • DL 기초 (6)
      • GAN (4)
      • 논문 리뷰 (10)
      • 분석뉴비가 알면 좋은 것 (3)
    • 기타 (12)
      • Pandas (1)
      • Matplotlib (1)
      • Airflow (5)
      • Huggingface (2)
      • Git (1)
    • 대회 (0)
    • 세팅 (5)
      • RaspberryPi (2)
      • M1 (2)
      • Tistory (1)
    • 게임 (2)
      • 로스트아크 (2)
    • 일상 (1)

블로그 메뉴

  • 홈
  • 태그

공지사항

  • 소개

인기 글

태그

  • LLM
  • 벡터스토어
  • Stable Video Diffusion
  • GAN
  • Video LDM
  • deepseek
  • 에어플로우
  • vectorstore
  • 스테이블 디퓨전
  • LLaMA
  • airflow
  • 인라인 수학기호좀 쓰고싶어요
  • retriever
  • 로스트아크
  • 디퓨전
  • 딥시크
  • Stable Diffusion
  • 논문리뷰
  • 랭체인
  • Diffusion

최근 댓글

최근 글

hELLO · Designed By 정상우.
새우까앙
클래스 상속, super() __init()
상단으로

티스토리툴바

개인정보

  • 티스토리 홈
  • 포럼
  • 로그인

단축키

내 블로그

내 블로그 - 관리자 홈 전환
Q
Q
새 글 쓰기
W
W

블로그 게시글

글 수정 (권한 있는 경우)
E
E
댓글 영역으로 이동
C
C

모든 영역

이 페이지의 URL 복사
S
S
맨 위로 이동
T
T
티스토리 홈 이동
H
H
단축키 안내
Shift + /
⇧ + /

* 단축키는 한글/영문 대소문자로 이용 가능하며, 티스토리 기본 도메인에서만 동작합니다.