Torch already provides a whole bunch of components covering the most commonly used operations for (convolutional) neural networks — convolutions, pooling, up-/downsampling, batch normalization ... Still, it happens that a more specialized component is needed. This might happen, for example, in research or for specific applications. While chances are good that Torch's community already provides implementations for well-known papers, it might still be necessary to write components oneself.
In most cases, components can easily be implemented in LUA. The official Torch documentation includes a short primer, and many more examples can be found all over the web (especially on GitHub) or in my article of Torch examples. Sometimes, however, this is not enough: some components simply cannot be implemented efficiently using existing Torch operations. In this article, I want to provide an easy way of extending Torch using C and CUDA without requiring any knowledge of Torch's inner workings.
The code corresponding to this article can also be found on GitHub:
Example on GitHub
Overview
Independent of Torch, LUA (more precisely, LuaJIT) can be extended with C/CUDA code using the ffi library. This library allows loading shared object files and calling the functions defined therein from LUA. However, this way of interfacing between LUA and C only supports standard C types, so Torch tensors cannot be passed directly.
Luckily, Torch provides direct access to the underlying data of a tensor through the :data() method. As stated in the documentation:
Returns a LuaJIT FFI pointer to the raw data of the tensor.
Putting both together, we can implement C/CUDA functions that operate on raw C arrays and make them available in LUA through the corresponding shared object file, which is loaded at runtime.
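Both sides only have to agree on how a flat array encodes a multi-dimensional tensor; for a contiguous tensor, this is row-major (C) order. The following C++ sketch, using the hypothetical helpers row_major_strides and linear_offset (not part of the original code), computes this mapping. For a 5D tensor, it yields the same index as the expression (((b*channels + c)*height + h)*width + w)*depth + d used in the templates.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: row-major strides (in elements) of a contiguous tensor
// described by its rank and sizes, as passed from LUA via tensor:size():data().
std::vector<long> row_major_strides(int rank, const long* dims) {
    assert(rank >= 0);
    std::vector<long> strides(static_cast<std::size_t>(rank), 1L);
    // The last dimension has stride 1; each earlier stride is the product
    // of all later sizes.
    for (int i = rank - 2; i >= 0; --i) {
        strides[i] = strides[i + 1] * dims[i + 1];
    }
    return strides;
}

// Hypothetical helper: offset of a multi-index into the flat data array.
long linear_offset(int rank, const long* dims, const long* index) {
    const std::vector<long> strides = row_major_strides(rank, dims);
    long offset = 0;
    for (int i = 0; i < rank; ++i) {
        assert(index[i] >= 0 && index[i] < dims[i]);  // index must be in bounds
        offset += index[i] * strides[i];
    }
    return offset;
}
```

For sizes (2, 3, 4), the strides are (12, 4, 1), so the element at index (1, 2, 3) sits at offset 1*12 + 2*4 + 3 = 23.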
C/CUDA Template
The C template is rather simple, as it merely sketches a possible Torch operation. For a simple identity module, the C code is split into a header file, test_identity_module.h:
```cpp
extern "C" {
    void test_identity_module_updateOutput(const int rank, const long* dims,
        const float* input, float* output);
    void test_identity_module_updateGradInput(const int rank, const long* dims,
        const float* input, const float* grad_output, float* grad_input);
}
```
The corresponding source file, test_identity_module.cpp, implements both functions:

```cpp
#include <cassert>

#include "test_identity_module.h"

void test_identity_module_updateOutput(const int rank, const long* dims,
        const float* input, float* output) {
    assert(rank == 5);

    const int batch_size = dims[0];
    const int channels = dims[1];
    const int height = dims[2];
    const int width = dims[3];
    const int depth = dims[4];

    // Row-major traversal of the 5D tensor; output = input.
    for (int b = 0; b < batch_size; b++) {
        for (int c = 0; c < channels; c++) {
            for (int h = 0; h < height; h++) {
                for (int w = 0; w < width; w++) {
                    for (int d = 0; d < depth; d++) {
                        const int i = (((b*channels + c)*height + h)*width + w)*depth + d;
                        output[i] = input[i];
                    }
                }
            }
        }
    }
}

void test_identity_module_updateGradInput(const int rank, const long* dims,
        const float* input, const float* grad_output, float* grad_input) {
    assert(rank == 5);

    const int batch_size = dims[0];
    const int channels = dims[1];
    const int height = dims[2];
    const int width = dims[3];
    const int depth = dims[4];

    // The gradient of the identity is passed through unchanged.
    for (int b = 0; b < batch_size; b++) {
        for (int c = 0; c < channels; c++) {
            for (int h = 0; h < height; h++) {
                for (int w = 0; w < width; w++) {
                    for (int d = 0; d < depth; d++) {
                        const int i = (((b*channels + c)*height + h)*width + w)*depth + d;
                        grad_input[i] = grad_output[i];
                    }
                }
            }
        }
    }
}
```
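Since the identity does not depend on an element's position, the five nested loops visit every element exactly once in row-major order, and the same result can be obtained with a single pass over all elements. The following sketch shows such a rank-agnostic variant; num_elements and flat_identity_updateOutput are hypothetical names, not part of the original template.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper: number of elements described by rank and dims.
long num_elements(int rank, const long* dims) {
    long n = 1;
    for (int i = 0; i < rank; ++i) {
        n *= dims[i];
    }
    return n;
}

// Hypothetical rank-agnostic identity forward pass: one flat copy over all
// elements replaces the nested loops, so any rank >= 1 is accepted.
extern "C" void flat_identity_updateOutput(const int rank, const long* dims,
                                           const float* input, float* output) {
    assert(rank >= 1);
    const long n = num_elements(rank, dims);
    std::copy(input, input + n, output);
}
```

For operations whose result does depend on the position (for example, convolutions or pooling), the explicit nested loops of the template remain the clearer starting point.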
Note that both the forward pass and the backward pass, corresponding to the updateOutput and updateGradInput methods of any Torch nn module, are declared in the header file. The extern "C" block disables C++ name mangling, so the functions are exported from the shared object under their plain names, which is what the ffi library looks up. Additionally, the data is passed as plain arrays; therefore, we always pass the rank of the tensor and its size (as a long array) in addition to the actual data. This illustrates that the code is completely independent of any C library (including the Torch core); still, custom libraries can be used to implement complex operations (for example, Eigen for matrix computations).
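The templates validate the rank with assert, which aborts the whole LUA process on failure and is compiled out entirely when NDEBUG is defined. A possible alternative, sketched below under the assumption that the LUA side is adapted accordingly, is to report errors through an int return value. The status codes and the function name checked_identity_updateOutput are hypothetical.

```cpp
#include <cassert>

// Hypothetical status codes; not part of the original template.
enum { TIM_OK = 0, TIM_BAD_RANK = 1, TIM_BAD_SIZE = 2, TIM_NULL_POINTER = 3 };

// Hypothetical variant of the identity forward pass that validates its
// arguments and returns a status code instead of calling assert().
extern "C" int checked_identity_updateOutput(const int rank, const long* dims,
                                             const float* input, float* output) {
    if (rank != 5) return TIM_BAD_RANK;
    if (dims == nullptr || input == nullptr || output == nullptr) return TIM_NULL_POINTER;

    long n = 1;
    for (int i = 0; i < rank; ++i) {
        if (dims[i] < 1) return TIM_BAD_SIZE;  // reject empty or negative sizes
        n *= dims[i];
    }
    for (long i = 0; i < n; ++i) {
        output[i] = input[i];
    }
    return TIM_OK;
}
```

On the LUA side, the corresponding ffi.cdef declaration would then have an int return type, and the module could call error() whenever the returned code is non-zero.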
The CUDA template is rather similar — the header file is even identical:
```cpp
extern "C" {
    void test_identity_module_updateOutput(const int rank, const long* dims,
        const float* input, float* output);
    void test_identity_module_updateGradInput(const int rank, const long* dims,
        const float* input, const float* grad_output, float* grad_input);
}
```
The CUDA source file, test_identity_module.cu, then reads:

```cuda
#include <cassert>

#include "cuda_helper.h"
#include "test_identity_module.h"

// One block per (batch, channel) pair; one thread per (height, width, depth) element.
__global__ void kernel_identity(const float* d_input, float* d_output) {
    //const int batch_size = gridDim.x;
    const int channels = gridDim.y;
    const int height = blockDim.x;
    const int width = blockDim.y;
    const int depth = blockDim.z;

    const int b = blockIdx.x;
    const int c = blockIdx.y;
    const int h = threadIdx.x;
    const int w = threadIdx.y;
    const int d = threadIdx.z;

    const int i = (((b*channels + c)*height + h)*width + w)*depth + d;
    d_output[i] = d_input[i];
}

void test_identity_module_updateOutput(const int rank, const long* dims,
        const float* d_input, float* d_output) {
    assert(rank == 5);

    const int batch_size = dims[0];
    const int channels = dims[1];
    const int height = dims[2];
    const int width = dims[3];
    const int depth = dims[4];

    dim3 grid(batch_size, channels, 1);
    dim3 block(height, width, depth);
    kernel_identity<<<grid, block>>>(d_input, d_output);
}

void test_identity_module_updateGradInput(const int rank, const long* dims,
        const float* d_input, const float* d_grad_output, float* d_grad_input) {
    assert(rank == 5);

    const int batch_size = dims[0];
    const int channels = dims[1];
    const int height = dims[2];
    const int width = dims[3];
    const int depth = dims[4];

    dim3 grid(batch_size, channels, 1);
    dim3 block(height, width, depth);
    kernel_identity<<<grid, block>>>(d_grad_output, d_grad_input);
}
```
Here, the identity operation is simply parallelized: the corresponding kernel, kernel_identity, copies one tensor element per thread, with one block per (batch, channel) pair. Note that the function declarations are completely identical to the C case; the only difference is that the data pointers now refer to GPU memory.
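This launch configuration puts a whole (height, width, depth) slice into a single block, which only works for small spatial sizes: on devices of compute capability 2.0 and later, a block may contain at most 1024 threads, blockDim.x and blockDim.y are limited to 1024, blockDim.z to 64, and gridDim.y to 65535. The following host-side C++ sketch checks these limits and computes an alternative flat configuration for a grid-stride kernel (each thread processing indices i, i + gridDim.x * blockDim.x, and so on), which handles tensors of any shape. The function names and the hard-coded limits are assumptions; a real implementation would query the limits via cudaGetDeviceProperties.

```cpp
#include <cassert>

// Hypothetical limits, hard-coded for compute capability 2.0 and later.
const long kMaxThreadsPerBlock = 1024;
const long kMaxBlockDimXY = 1024;
const long kMaxBlockDimZ = 64;
const long kMaxGridDim = 65535;  // conservative: gridDim.x may be larger on newer GPUs

// Returns true if the template's launch configuration
// grid(batch_size, channels, 1), block(height, width, depth) is valid.
bool template_launch_is_valid(const long* dims) {
    const long batch_size = dims[0], channels = dims[1];
    const long height = dims[2], width = dims[3], depth = dims[4];
    if (batch_size > kMaxGridDim || channels > kMaxGridDim) return false;
    if (height > kMaxBlockDimXY || width > kMaxBlockDimXY || depth > kMaxBlockDimZ) return false;
    return height * width * depth <= kMaxThreadsPerBlock;
}

// Hypothetical flat 1D launch configuration for a grid-stride kernel.
struct FlatLaunch {
    long blocks;
    long threads;
};

FlatLaunch flat_launch_config(long num_elements, long threads_per_block = 256) {
    long blocks = (num_elements + threads_per_block - 1) / threads_per_block;
    if (blocks > kMaxGridDim) blocks = kMaxGridDim;  // the grid-stride loop covers the rest
    if (blocks < 1) blocks = 1;
    return {blocks, threads_per_block};
}
```

For example, a tensor of size 8x16x32x32x4 would need 4096 threads per block with the template's configuration and thus fails to launch, whereas the flat configuration uses 2048 blocks of 256 threads.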
The above templates can easily be compiled without any dependencies except for CUDA; for the C part, this is done in CMake using
```cmake
add_library(cpu SHARED test_module.cpp test_identity_module.cpp)
```
Note that the library is compiled as a shared object. For CUDA, it is important to check which compute capability the target GPU supports (see, for example, NVIDIA's documentation) and to set the -arch flag accordingly. Then, the library can similarly be compiled into a shared library:
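```cmake
find_package(CUDA REQUIRED)
# Adapt -arch to the compute capability of the target GPU;
# note that sm_20 (Fermi) is no longer supported by CUDA 9 and later.
list(APPEND CUDA_NVCC_FLAGS "-arch=sm_20;-std=c++11;-O2;-DVERBOSE")
set(CUDA_PROPAGATE_HOST_FLAGS OFF)
cuda_add_library(gpu SHARED test_identity_module.cu)
target_link_libraries(gpu ${CUDA_LIBRARIES})
```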
LUA Template
The LUA template mainly consists of using the ffi package to locate and load the compiled shared libraries:
```lua
local ffi = require('ffi')

-- Will contain all C modules later ...
cpu = {}

-- Declarations copied from test_identity_module.h (without the extern "C" block).
ffi.cdef[[
void test_identity_module_updateOutput(const int rank, const long* dims, const float* input, float* output);
void test_identity_module_updateGradInput(const int rank, const long* dims, const float* input, const float* grad_output, float* grad_input);
]]

--- Get the script path;
-- this assumes that the LUA modules are contained in ../th/.
-- @return script path
local function scriptPath()
  local str = debug.getinfo(2, "S").source:sub(2)
  -- Fall back to the current directory if the path contains no slash.
  return str:match("(.*/)") or './'
end

-- The path to the shared library.
local libname = scriptPath() .. '../cpp/cpu/build/libcpu.so'
local found = pcall(function () cpu = ffi.load(libname) end)

if found then
  print('[Lib] found ' .. libname)
else
  print('[Info] could not find CPU module, tried ' .. libname)
  print('[Info] will continue without CPU module')
  cpu = false
end
```
This template basically performs two steps. First, the function declarations are provided; these can be copied directly from the corresponding header file (without the surrounding extern "C" block). Second, the shared library is located and loaded. If both steps succeed, the implementations are available through cpu, so they can be called as, for example, cpu.test_identity_module_updateOutput.
Torch Module
The nn module follows the general template of defining the updateOutput and updateGradInput methods:
```lua
--- @class TestIdentityModule
local TestIdentityModule, TestIdentityModuleParent = torch.class('nn.TestIdentityModule', 'nn.Module')

--- Initialize.
function TestIdentityModule:__init()
  TestIdentityModuleParent.__init(self)
end

--- Compute forward pass.
-- @param input layer input
-- @return output
function TestIdentityModule:updateOutput(input)
  -- :data() ignores strides, so make sure the memory layout is contiguous.
  input = input:contiguous()
  -- Allocate an output of the same size; the C code overwrites its contents.
  self.output = input:clone()
  -- Pass the actual rank so that the C-side assert catches unexpected shapes.
  cpu.test_identity_module_updateOutput(input:dim(), input:size():data(), input:data(), self.output:data())
  return self.output
end

--- Compute the backward pass.
-- @param input original input
-- @param gradOutput gradients of top layer
-- @return gradients with respect to input
function TestIdentityModule:updateGradInput(input, gradOutput)
  input = input:contiguous()
  gradOutput = gradOutput:contiguous()
  self.gradInput = gradOutput:clone()
  cpu.test_identity_module_updateGradInput(input:dim(), input:size():data(), input:data(), gradOutput:data(), self.gradInput:data())
  return self.gradInput
end
```
The tensors are passed to the C implementation by accessing their underlying arrays via tensor:data(). Similarly, tensor:size() returns a torch.LongStorage whose :data() pointer can be passed as the long array of sizes.
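One caveat: tensor:data() only returns a pointer to the first element and carries no stride information, so the C code implicitly assumes a contiguous, row-major tensor; a transposed or narrowed tensor would silently be read in the wrong order. On the LUA side, calling :contiguous() (which returns the tensor itself if it is already contiguous and a contiguous copy otherwise) before :data() avoids this. Alternatively, the strides, available via tensor:stride():data(), could be passed as well and checked in C. The sketch below, with the hypothetical function is_row_major_contiguous, shows such a check; it skips dimensions of size 1, whose stride never affects the memory layout.

```cpp
#include <cassert>

// Hypothetical check: a tensor is contiguous in row-major order if, walking
// from the last dimension to the first, each stride (in elements) equals the
// product of the sizes of all later dimensions.
bool is_row_major_contiguous(int rank, const long* sizes, const long* strides) {
    long expected = 1;
    for (int d = rank - 1; d >= 0; --d) {
        if (sizes[d] == 1) continue;  // stride of a size-1 dimension is irrelevant
        if (strides[d] != expected) return false;
        expected *= sizes[d];
    }
    return true;
}
```

A 2x3 tensor with strides (3, 1) passes the check, while its transpose, a 3x2 view with strides (1, 3), does not.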
To support both C and CUDA implementations, we can dispatch on the type of the input tensor:
```lua
if input:type() == 'torch.FloatTensor' then
  cpu.test_identity_module_updateGradInput(input:dim(), input:size():data(), input:data(), gradOutput:data(), self.gradInput:data())
elseif input:type() == 'torch.CudaTensor' then
  gpu.test_identity_module_updateGradInput(input:dim(), input:size():data(), input:data(), gradOutput:data(), self.gradInput:data())
else
  error('unsupported tensor type ' .. input:type())
end
```
This assumes that the CUDA library has been loaded into gpu, as done for the C implementation.
Conclusion
Overall, the provided template makes it easy to implement Torch modules in C and CUDA without any additional dependencies or knowledge of the Torch core. Furthermore, the template is flexible, as external libraries can easily be integrated; this allows re-using existing implementations of complex operations. Finally, the resulting Torch modules can be efficient, as the C/CUDA implementations directly access the underlying data of Torch tensors. Data management can thus be handled, for the most part, in LUA, while external libraries (for example, Eigen) can access the same data directly without needlessly copying large tensors.