[PyTorch] PyTorch C++ Custom Operator 추가하기
이전 글에서 PyTorch의 C++ API를 이야기 하면서 PyTorch의 내부 구조가 어떻게 구성되어 있고, 거기와 관련된 기능들을 확인했었습니다. 이번에는 이전 글에서 이야기 하지 못한 TorchScript 와 C++ Extensions에 대해서 간단히 이야기 해보고, 나만의 Custom Operator를 추가하는 방법에 대해서 이야기 해보겠습니다.
TorchScript
TorchScript은PyTorch 1.0부터 새롭게 도입되었고, C++ 같은 고성능 환경에서 실행할 수 있는 PyTorch 모델의 중간 표현 입니다. Python의 하위 집합으로 TorchScript 컴파일러을 통해 구문 분석, 컴파일 및 최적화할 수 있습니다. 이를 통해 모델을 언어와 하드웨어에 독립적인 형태로 저장하고, 나중에 다른 환경(C++ 등등..)에서 불러와 실행할 수 있습니다. 또한 PyTorch에서는 동작의 유연성을 위해서 동적 그래프를 사용하지만, TorchScript는 정적 그래프(모델의 구조가 고정)를 사용하여, 매번 실행 시 그래프를 재구성해야 하므로 오버헤드가 발생하지 않습니다. 또한 모델이 고정되어 있기 전처리 단계에서 모델의 구조를 파악한 후에 이에 대한 최적화(Dead Code Elimination, Loop unrolling, ...)를 수행하게 됩니다. 따라서 모델을 더 빠르고 효율적으로 실행할 수 있게 해줍니다.
Torchscript의 C ++ 인터페이스에는 세 가지 기본 기능이 포함됩니다.
- Python에서 정의 된 직렬화 된 Torchscript 모델을 로드 및 실행
# save.py jit_net = torch.jit.script(net) torch.jit.save(jit_net, save_path) # load.py jit_net = torch.jit.load(load_path)
- Torchscript 표준 운영 라이브러리를 확장하는 사용자 정의 연산자를 정의하기위한 API
- C ++의 TorchScript 프로그램의 JIT(Just-in-time) 컴파일러
C++ Extensions
C++ Extensions은 PyTorch의 사용자 사용에 대한 확장을 위해 PyTorch의 모든 인터페이스에 간단하게 접근하는 방법을 제공합니다. 그렇기 때문에 C++ Extension API는 Pytorch C ++ API에 새로운 기능을 추가하거나 그러진 않고, 대신 Python Setupools와의 통합 및 Aten, Autograd 및 Python의 기타 C ++ API에 액세스 할 수있는 JIT 컴파일 메커니즘을 제공합니다.
즉, 사용자가 소스 외부에 정의된 PyTorch 연산자를 생성할 수 있도록 개발된 메커니즘입니다. 그러면 이제는 사용자 지정 연산자를 구현해 보면서 위에서 언급한 내용들을 확인해 보는 시간을 가져보겠습니다.
Custom Operator
Custom Operator는 아래의 문서에 굉장히 잘 표현되어 있습니다. 이번 시간에 아래의 문서를 참고하여, CPU, GPU에서 동작하는 간단한 Operation을 만들어 보겠습니다.
(참고로 아래의 예제는 pytorch 2.2.0a0버전, git commit(4cf97c40)에서 수행되었습니다.)
Extending TorchScript with Custom C++ Operators — PyTorch Tutorials 2.2.0+cu121 documentation
Extending TorchScript with Custom C++ Operators — PyTorch Tutorials 2.2.0+cu121 documentation
Extending TorchScript with Custom C++ Operators The PyTorch 1.0 release introduced a new programming model to PyTorch called TorchScript. TorchScript is a subset of the Python programming language which can be parsed, compiled and optimized by the TorchScr
pytorch.org
pytorch 프로젝트의 최상단에 새롭게 만들 custom operation을 작업할 폴더(my_op)를 생성해주고, 해당 폴더에 새롭게 정의할 my_shift_op.cu을 생성하였습니다. 이 operation은 input과 이동 시킬 값을 받아서 해당 값만큼 값들을 이동시키는 굉장히 간단한 operation입니다.
(참고로, 이 글에서는 shift 구현 코드에 대해서는 이야기 하지 않도록 하겠습니다.)
#include "my_shift_op.cuh"
#include <torch/script.h>
#include <ATen/cuda/detail/KernelUtils.h>
using namespace at::cuda::detail;
__global__ void gup_shift(float* input, float* output, int64_t move) {
// ...
}
void cpu_shift(float* input, float* ouput, int64_t move) {
// ...
}
torch::Tensor shift_op(torch::Tensor input, int64_t move) {
torch::Device device(torch::kCUDA, 0);
// 입력 데이터와 같은 크기의 output Tensor를 생성
torch::Tensor output = torch::zeros(input.size(0), torch::kFloat);
// output Tensor에 현재 입력의 디바이스정보를 주입
output = output.to(input);
// 현재 입력 Tensor의 디바이스에 따라서 분기
if (input.device() == device){
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const dim3 grid(GET_BLOCKS(input.size(0)));
const dim3 block(CUDA_NUM_THREADS);
gup_shift<<<grid, block, 0, stream>>>(input.data_ptr<float>(), output.data_ptr<float>(), move);
} else {
cpu_shift(input.data_ptr<float>(), output.data_ptr<float>(), move);
}
return output;
}
위와 같이 custom operation 코드를 추가하였습니다. torch/script.h를 추가함으로써 사용자 정의 TorchScript 연산자를 구현하는 데 필요한 PyTorch의 C++ API 모든 기능들을 사용할 수 있습니다. 그리고 GET_BLOCKS, CUDA_NUM_THREADS와 같은 메크로 함수를 사용하기 위해서 ATen/cuda/detail/KernelUtils.h를 포함하였습니다.
참고로 함수의 인자들에 int, float과 같은 타입은 사용할 수 없습니다. 대신 int64_t, double을 사용해야 합니다.
만약 사용하게 된다면, 컴파일 과정에서 아래와 같은 에러가 발생하게 됩니다.
사용자 정의 연산자의 구현이 완료되면 TorchScript 런타임 및 컴파일러에 우리가 만든 사용자 정의 연산자를 등록해야 합니다. 이 작업을 통해서 TorchScript 컴파일러는 TorchScript 코드에서 사용자 정의 연산자에 대한 참조를 확인할 수 있게 됩니다. 저는 사용자 정의 연산자 구현코드와 같은 레벨에 아래의 코드를 추가해주었습니다.
(pybind11 구문과 매우 유사하여, pybind11를 사용해보신 분이라면 쉽게 이해가 가능합니다.)
#include <torch/torch.h>
#include <torch/script.h>
#include "my_shift_kernel.cuh"
TORCH_LIBRARY (my_shift_ops, m){
m.def("custom_shift", shift_op);
}
매크로( TORCH_LIBRARY )의 첫 번째 인자는 사용자 정의 연산자 라이브러리의 이름을 정의합니다. 이 이름은 따옴표로 표시 않고 지정합니다. (이 이름은 나중에 namespace로 사용됩니다.) 그리고 두 번째 인자 'm'은 연산자를 등록하는 데 사용되는 torch::Library에 바인딩됩니다.
현재는 등록할 사용자 정의 연산자가 하나라 하나만 등록했지만, def 함수를 여러 번 호출하여 원하는 만큼 많은 연산자를 정의할 수 있습니다. 참고로 def 함수는 템플릿 메타프로그래밍을 사용하여 함수의 유형 서명을 검사하고 이를 TorchScript 유형 시스템 내에서 연산자 유형을 지정하는 연산자 스키마로 변환하는 작업을 수행합니다.
위와 같이 코드 구현이 완료되었으면, 빌드를 위해 CMakeLists.txt파일을 작성해야 합니다.
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(my_shift_op LANGUAGES CXX CUDA)
find_package(Torch REQUIRED)
add_library(my_shift_op SHARED my_shift_op.cpp my_shift_kernel.cu)
// C++14을 타킷으로 설정
target_compile_features(my_shift_op PRIVATE cxx_std_14)
// LibTorch를 링크
target_link_libraries(my_shift_op "${TORCH_LIBRARIES}")
그 이후에 의존성(LibTorch)을 위해 PyTorch를 설치해야 합니다. 가장 쉽고 플랫폼 의존성없이 설치하는 방법은 Conda를 활용하는 것입니다.
conda install -c pytorch pytorch
설치 이후에 bulid라는 폴더를 생성하고, 이후에 build에서 아래의 명령어로 빌드를 수행할 수 있습니다.
(LibTorch의 cmake 파일들의 경로를 CMake에 알려주기 위해서 DCMAKE_PREFIX_PATH를 설정하고 있습니다.)
$ cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..
$ cmake --build .
빌드가 완료되면 build 폴더 안에 쉐어드 라이브러리(.so)가 제대로 생성된것을 볼 수 있습니다.
빌드가 완료되었으면 이제 우리가 구현된 사용자 정의 연산자가 잘 동작하는지 확인해 봐야 합니다. 가장 간단하게 사용자 정의 연산자가 잘 등록이 되었는지 확인 하는 방법은 아래와 같습니다.
이 밖에 Python, C++에서 구현한 연산자를 사용하여 테스트 하는 방법은 이 링크에서 확인할 수 있습니다.
또한 전체 코드는 이 링크에서 확인 가능합니다.
참고로, 우리는 지금까지 과정에서 back propagation을 위한 backward함수를 구현하지 않았습니다. 따라서 지금 구현한 custom operation은 추론만 하는 형태의 모델에서 사용하거나, 데이터 전처리나 후처리에서만 사용될 수 있습니다. 만약 학습에 사용될 수 있는 custom operation을 구현하고 싶다면 이 링크를 참고하시면 됩니다.
REFERENCE
TorchScript — PyTorch master documentation
TorchScript — PyTorch master documentation
TorchScript TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency. We provide tools to incrementally tran
pytorch.org
Custom C++ and CUDA Extensions — PyTorch Tutorials 2.2.0+cu121 documentation
Custom C++ and CUDA Extensions — PyTorch Tutorials 2.2.0+cu121 documentation
Custom C++ and CUDA Extensions Author: Peter Goldsborough PyTorch provides a plethora of operations related to neural networks, arbitrary tensor algebra, data wrangling and other purposes. However, you may still find yourself in need of a more customized o
pytorch.org