본문 바로가기
AI/Development

[PytorchToC++] 01. TorchScript 비교 및 simple code review

by _S0_H2_ 2021. 9. 28.
728x90
반응형

<< 전체 작업 flow >>

Python에서 Pytorch로 작성한 모델을 C++ 환경에서 사용하기 위해 Develop을 진행한다.

Python에서 jit로 모델을 작성한 후 C++에서 LibTorch를 사용하여 로드하고자 한다.

 

1. python 환경 

python 환경에서 jit로 모델을 모듈화하는 과정은 두 가지 방법이 있다.

[ torch_tracing ]
: 입력값을 사용하여 모델 구조를 파악한 뒤 입력값의 모델 안에서의 흐름을 통해 모델을 기록한다. flow가 기록되기 때문에 statically fix된 그래프이다.

[ annotation(script) ]
: torchscript 컴파일러가 직접 모델 코드를 분석하여 컴파일을 진행한다. 따라서 dynamic한 control flow(조건 분기, break 등)를 사용할 수 있다. 하지만 지원하지 않는 python code 및 type 추정 문제가 있어 모델 convert시 확인이 필요하다.

1 ) torch_tracing

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(3, 3)
        
    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h


my_cell = MyCell()
x, h = torch.ones(2, 3), torch.zeros(2, 3)
scripted_cell = torch.jit.trace(my_cell, example_inputs=(x, h))

a, b = torch.ones(4, 3), torch.zeros(1, 3)
print(scripted_cell.forward(a, b))

torch_tracing은 입력값을 주고 그 흐름을 사용하기 때문에 example_inputs가 필수 parameter이다.

 

이 사이에 사용하고 싶은 새로운 함수를 정의할 때 유의할 점이 있다. __init__ 에서 생성한 모듈이 가진 함수만 접근이 가능하다는 점이다. 간단히 만들어본 함수이다.

parameter는 다음과 같이 무언가를 출력하는 반면에

torch.nn.Linear가 갖고 있지 않은 메소드는 에러가 발생한다.

 

재사용함수를 모델에 구현하기 위한 두 가지 방법이 있다.

첫번째는 위의 코드에 @torch.jit.export 를 추가하는 것이다.

@torch.jit.export
    def check(self, x):
        return x

이를 출력하면 다음과 같이 나온다.

 

두번째는 class를 정의하는 것이다.

class MyDecisionGate2(torch.nn.Module):
    def forward(self, x):
        # 제어 흐름을 활용
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell2(torch.nn.Module):
    def __init__(self):  # 생성자 정의
        super(MyCell2, self).__init__() # 우선적으로 생성자 호출
        self.dg = MyDecisionGate2() # 실제 모델에서 'loop'에서 적용 가능한 함수
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h

my_cell = MyCell2()
print(my_cell)

x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

scripted_cell = torch.jit.trace(my_cell, (x, h))
print(scripted_cell.forward(x, h))

하지만 tracing의 경우 if, loop문이 있는 경우 함수를 사용하더라도 warning이 발생한다.

scripted_cell.code

위를 추가로 출력해보면 아래와 같은 Warning이 발생한다.

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if x.sum() > 0:

 

tracing의 경우 코드를 실행하고 발생하는 작업을 기록하며 정확하게 수행하는 scriptmodule을 구성하므로 제어 흐름이 지워지게 된다. 이 모듈을 제대로 나타내기 위해서는 python 코드를 직접 분석하여 TorchScript로 변환하는 script compiler를 사용할 수 있다.

 

더보기

script와 비교

 

class MyDecisionGate2(torch.nn.Module):
    def forward(self, x):
        # 제어 흐름을 활용
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell2(torch.nn.Module):
    def __init__(self, dg):  # 생성자 정의
        super(MyCell2, self).__init__() # 우선적으로 생성자 호출
        self.dg = dg # 실제 모델에서 'loop'에서 적용 가능한 함수
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h

scripted_gate = torch.jit.script(MyDecisionGate2())
my_cell = MyCell2(scripted_gate)
scripted_cell = torch.jit.script(my_cell)

print(scripted_gate.code)
print(scripted_cell.code)

이 때 code 부분을 보면 분기가 들어감을 확인할 수 있다.

 

 

2 ) annotation ( torch_script )

첫번째 예시와 같은 구조를 사용하여 모델을 만들어보려고 한다. 다른 방법으로 torch.jit.ScriptModule을 상속받는 방법이 있다. 그리고 사용하고자 하는 함수에 decorator를 사용한다.

 

decorator알아보기

 

class MyCell(torch.jit.ScriptModule):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(3, 3)

    @torch.jit.script_method
    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h

    @torch.jit.script_method
    def check(self, x):
        return x


my_cell = MyCell()
scripted_cell = torch.jit.script(my_cell)

a, b = torch.ones(4, 3), torch.zeros(1, 3)
print(scripted_cell.forward(a, b))
print(scripted_cell.check(a))

 

torchscript는 tracing과 다르게 example을 전달하지 않아도 된다. 함수 안에서도 함수를 사용하는 것도 가능해서 모델 구조를 작성할 때 아주 용이할 듯 하다!

class MyCell(torch.jit.ScriptModule):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(3, 3)

    @torch.jit.script_method
    def forward(self, x, h):
        new_h = torch.tanh(self.check(x) + h)
        return new_h

    @torch.jit.script_method
    def check(self, x):
        print("check function visited")
        return self.linear(x)

my_cell = MyCell()
a, b = torch.ones(4, 3), torch.zeros(1, 3)
scripted_cell = torch.jit.script(my_cell)
print(scripted_cell.forward(a, b))

 

728x90
반응형

'AI > Development' 카테고리의 다른 글

[PytorchToC++] 03. Visual Studio에 LibTorch 설치하기  (0) 2021.09.29
[PytorchToC++] 02. TorchScript 분석  (0) 2021.09.29