본문 바로가기
AI/Development

[PytorchToC++] 02. TorchScript 분석

by _S0_H2_ 2021. 9. 29.
728x90
반응형

torch_tracing과 torchscript 비교 에 이어진다.

 

1. TorchScript 분석

Module을 아래와 같이 구성하면 재사용성과 가독성이 높아진다. MyDecisionGate는 제어 흐름을 활용한다. 아래에 출력되는 grad_fn은 잠재적으로 복잡한 프로그램을 통해 미분을 계산할 수 있게된다. 코드 중 미분값을 명시적으로 정의할 필요가 없는 경우도 있는데 pytorch는 변화도 테이프를 사용하여 연산이 발생할 때만 이를 기록하고 미분값을 계산할 때 거꾸로 재생한다. (많은 프레임워크들이 프로그램 코드로부터 기호식 미분을 계산하는 접근법을 취하고 있음)

 

1 ) 모듈 1 : torch.nn.Module 을 상속받음

import torch
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        # 제어 흐름을 활용
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):  # 생성자 정의
        super(MyCell, self).__init__() # 우선적으로 생성자 호출
        self.dg = MyDecisionGate() # 실제 모델에서 'loop'에서 적용 가능한 함수
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h

my_cell = MyCell()
print(my_cell)

x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

scripted_cell = torch.jit.script(my_cell)
print(scripted_cell.forward(x, h))

 위의 코드는 pytorch tutorial에서 확인한 코드인데 작성하다보니 torch.jit.ScriptModule를 사용하였을 때의 차이점이 무엇인지 의문점이 생겼다.

 

2 ) 모듈2 : torch.jit.ScriptModule을 상속받음

class MyDecisionGate2(torch.nn.Module):
    def forward(self, x):
        # 제어 흐름을 활용
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell2(torch.jit.ScriptModule):
    def __init__(self):  # 생성자 정의
        super(MyCell2, self).__init__() # 우선적으로 생성자 호출
        self.dg = MyDecisionGate2() # 실제 모델에서 'loop'에서 적용 가능한 함수
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h

my_cell = MyCell2()
print(my_cell)

x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

scripted_cell = torch.jit.script(my_cell)
print(scripted_cell.forward(x, h))

 

MyCell이 torch.jit.ScriptModule을 상속받았다. 이 때 모델은 다음과 같이 기록된다.

여기에 추가로 @torch.jit.script_method를 forwar에 더해준다.

3 ) torch.jit.scriptmodule을 상속 받고 decorator를 사용함

class MyCell3(torch.jit.ScriptModule):
    def __init__(self):  # 생성자 정의
        super(MyCell3, self).__init__() # 우선적으로 생성자 호출
        self.dg = MyDecisionGate3() # 실제 모델에서 'loop'에서 적용 가능한 함수
        self.linear = torch.nn.Linear(4, 4)

    @torch.jit.script_method
    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h

my_cell = MyCell3()
print(my_cell)

x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

scripted_cell = torch.jit.script(my_cell)
print(scripted_cell.forward(x, h))

이 때는 autograd부분이 변경된다.

TanhBackward와 DifferentiableGraphBackward가 수행하는 방식에서 어떤 차이점이 있는지 확인이 필요하다.

 

이를 위해 세가지 버전을 모두 저장한 뒤 C++에서 확인해보고자 한다. 다음 코드로 모듈을 저장한다.

# 각각의 scripted_cell 밑에 저장한다
scripted_cell.save('MyCell1.pt')
scripted_cell.save('MyCell2.pt')
scripted_cell.save('MyCell3.pt')

 

2. C++ 에서 비교

Visual Studio에 libtorch를 설치한 뒤 위의 모듈을 load하여 실행해보자.

 

#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>

using namespace std;

int main()
{
	torch::jit::script::Module module1;
	torch::jit::script::Module module2;
	torch::jit::script::Module module3;

	try {
		module1 = torch::jit::load("C:/Project/PytorchTest/MyCell1.pt");
		module2 = torch::jit::load("C:/Project/PytorchTest/MyCell2.pt");
		module3 = torch::jit::load("C:/Project/PytorchTest/MyCell3.pt");
		std::cout << "module loaded" << std::endl;
	}
	catch(const c10::Error& e){
		std::cerr << "Error loading model\n";
		return -1;
	}

	torch::Tensor x = torch::rand ({ 3, 4 });
	torch::Tensor h = torch::rand({ 3, 4 });

	torch::Tensor out_tensor1 = module1.forward({ x, h }).toTensor();
	std::cout << "result : " << out_tensor1 << std::endl;

	torch::Tensor out_tensor3 = module3.forward({ x, h }).toTensor();
	std::cout << "result : " << out_tensor3 << std::endl;

	torch::Tensor out_tensor2 = module2.forward({ x, h }).toTensor();
	std::cout << "result : " << out_tensor2 << std::endl;

	return 0;
}

 

1, 3번 모듈은 실행이 가능한데 2번 모듈은 실행이 안된다. 이 부분은 코드 더 깊이 봐야 알 것 같다. 

잠정적 결론 : Class를 생성할 때 torch.jit.ScriptModule을 상속받으면 forward에는 decorator(@)를 함께하자 !

 

 

 

 

 

728x90
반응형