<< 전체 작업 flow >>
Python에서 Pytorch로 작성한 모델을 C++ 환경에서 사용하기 위해 Develop을 진행한다.
Python에서 jit로 모델을 작성한 후 C++에서 LibTorch를 사용하여 로드하고자 한다.
1. python 환경
python 환경에서 jit로 모델을 모듈화하는 과정은 두 가지 방법이 있다.
[ torch_tracing ]
: 입력값을 사용하여 모델 구조를 파악한 뒤 입력값의 모델 안에서의 흐름을 통해 모델을 기록한다. flow가 기록되기 때문에 statically fix된 그래프이다.
[ annotation(script) ]
: torchscript 컴파일러가 직접 모델 코드를 분석하여 컴파일을 진행한다. 따라서 dynamic한 control flow(조건 분기, break 등)를 사용할 수 있다. 하지만 지원하지 않는 python code 및 type 추정 문제가 있어 모델 convert시 확인이 필요하다.
1 ) torch_tracing
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(3, 3)
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h)
return new_h
my_cell = MyCell()
x, h = torch.ones(2, 3), torch.zeros(2, 3)
scripted_cell = torch.jit.trace(my_cell, example_inputs=(x, h))
a, b = torch.ones(4, 3), torch.zeros(1, 3)
print(scripted_cell.forward(a, b))
torch_tracing은 입력값을 주고 그 흐름을 사용하기 때문에 example_inputs가 필수 parameter이다.
이 사이에 사용하고 싶은 새로운 함수를 정의할 때 유의할 점이 있다. __init__ 에서 생성한 모듈이 가진 함수만 접근이 가능하다는 점이다. 간단히 만들어본 함수이다.
parameter는 다음과 같이 무언가를 출력하는 반면에
torch.nn.Linear가 갖고 있지 않은 메소드는 에러가 발생한다.
재사용함수를 모델에 구현하기 위한 두 가지 방법이 있다.
첫번째는 위의 코드에 @torch.jit.export 를 추가하는 것이다.
@torch.jit.export
def check(self, x):
return x
이를 출력하면 다음과 같이 나온다.
두번째는 class를 정의하는 것이다.
class MyDecisionGate2(torch.nn.Module):
def forward(self, x):
# 제어 흐름을 활용
if x.sum() > 0:
return x
else:
return -x
class MyCell2(torch.nn.Module):
def __init__(self): # 생성자 정의
super(MyCell2, self).__init__() # 우선적으로 생성자 호출
self.dg = MyDecisionGate2() # 실제 모델에서 'loop'에서 적용 가능한 함수
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.dg(self.linear(x)) + h)
return new_h
my_cell = MyCell2()
print(my_cell)
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))
scripted_cell = torch.jit.trace(my_cell, (x, h))
print(scripted_cell.forward(x, h))
하지만 tracing의 경우 if, loop문이 있는 경우 함수를 사용하더라도 warning이 발생한다.
scripted_cell.code
위를 추가로 출력해보면 아래와 같은 Warning이 발생한다.
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if x.sum() > 0:
tracing의 경우 코드를 실행하고 발생하는 작업을 기록하며 정확하게 수행하는 scriptmodule을 구성하므로 제어 흐름이 지워지게 된다. 이 모듈을 제대로 나타내기 위해서는 python 코드를 직접 분석하여 TorchScript로 변환하는 script compiler를 사용할 수 있다.
script와 비교
class MyDecisionGate2(torch.nn.Module):
def forward(self, x):
# 제어 흐름을 활용
if x.sum() > 0:
return x
else:
return -x
class MyCell2(torch.nn.Module):
def __init__(self, dg): # 생성자 정의
super(MyCell2, self).__init__() # 우선적으로 생성자 호출
self.dg = dg # 실제 모델에서 'loop'에서 적용 가능한 함수
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.dg(self.linear(x)) + h)
return new_h
scripted_gate = torch.jit.script(MyDecisionGate2())
my_cell = MyCell2(scripted_gate)
scripted_cell = torch.jit.script(my_cell)
print(scripted_gate.code)
print(scripted_cell.code)
이 때 code 부분을 보면 분기가 들어감을 확인할 수 있다.
2 ) annotation ( torch_script )
첫번째 예시와 같은 구조를 사용하여 모델을 만들어보려고 한다. 다른 방법으로 torch.jit.ScriptModule을 상속받는 방법이 있다. 그리고 사용하고자 하는 함수에 decorator를 사용한다.
class MyCell(torch.jit.ScriptModule):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(3, 3)
@torch.jit.script_method
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h)
return new_h
@torch.jit.script_method
def check(self, x):
return x
my_cell = MyCell()
scripted_cell = torch.jit.script(my_cell)
a, b = torch.ones(4, 3), torch.zeros(1, 3)
print(scripted_cell.forward(a, b))
print(scripted_cell.check(a))
torchscript는 tracing과 다르게 example을 전달하지 않아도 된다. 함수 안에서도 함수를 사용하는 것도 가능해서 모델 구조를 작성할 때 아주 용이할 듯 하다!
class MyCell(torch.jit.ScriptModule):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(3, 3)
@torch.jit.script_method
def forward(self, x, h):
new_h = torch.tanh(self.check(x) + h)
return new_h
@torch.jit.script_method
def check(self, x):
print("check function visited")
return self.linear(x)
my_cell = MyCell()
a, b = torch.ones(4, 3), torch.zeros(1, 3)
scripted_cell = torch.jit.script(my_cell)
print(scripted_cell.forward(a, b))
'AI > Development' 카테고리의 다른 글
[PytorchToC++] 03. Visual Studio에 LibTorch 설치하기 (0) | 2021.09.29 |
---|---|
[PytorchToC++] 02. TorchScript 분석 (0) | 2021.09.29 |