Introduction
PyTorch has become a standard tool for deep learning research and development. Even uncommon tensor operations or neural network layers can easily be implemented using the variety of operations provided by PyTorch. However, for some special operations, it might make sense to resort to efficient C and CUDA implementations. In this article, I want to show how PyTorch can be extended easily using CFFI for CPU operations and CuPy for GPU operations. As an example, I will show how to implement a tensor operation computing the element-wise, bit-level Hamming distance between two tensors.
The code for this article is available on GitHub and can easily be extended and adapted:
Code on GitHub

Setup
I will use Python's C Foreign Function Interface (CFFI) to implement the Hamming distance between tensors of data type torch.int32 on CPUs. In order to support GPU computation, the Hamming distance can also be implemented as a CUDA kernel. The CUDA kernel can be imported using CuPy. Both CFFI and CuPy can easily be installed, for example, using pip install. For CuPy, however, the installation needs to fit the used CUDA version (as is also necessary for PyTorch). Detailed installation instructions can be found here.
Bit Hamming Distance in C
A simple implementation of the bit-level Hamming distance between two 32-bit integers in C is provided below:
Listing 1: Hamming distance between two integers.
int a = 15;
int b = 19;

int dist = 0;
int x = a ^ b;
while(x != 0) {
    x = x & (x-1);
    dist++;
}
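The loop relies on the classic trick that x & (x - 1) clears the lowest set bit of x, so it runs exactly once per differing bit. As a quick sanity check, here is the same logic in Python (not part of the original listing):

a = 15      # 0b01111
b = 19      # 0b10011
x = a ^ b   # 0b11100

dist = 0
while x != 0:
    x = x & (x - 1)  # clear the lowest set bit
    dist += 1

print(dist)                    # 3
print(bin(a ^ b).count('1'))   # 3, same result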
In order to compute the Hamming distance element-wise between two tensors, Listing 1 can be wrapped in a simple loop. The resulting function, shown in Listing 2, expects three int arrays as input: the first input tensor, the second input tensor, and the output tensor that will be filled with the Hamming distances. All arrays are assumed to be of the same length n:
Listing 2: Element-wise Hamming distance between two arrays of integers. Optionally, OpenMP can be used to speed up computation.
void cffi_int32hammingdistance(
    const int n,
    const int* a,
    const int* b,
    int* dist
) {
    #pragma omp parallel for
    for (int elem_idx = 0; elem_idx < n; elem_idx++) {
        dist[elem_idx] = 0;
        int x = a[elem_idx] ^ b[elem_idx];
        while(x != 0) {
            x = x & (x-1);
            dist[elem_idx]++;
        }
    }
}
Using CFFI, this function can be used directly to operate on NumPy arrays or PyTorch tensors. To this end, the implementation can be saved, for example, in cffi.c, and the corresponding header file in cffi.h; the header only needs to contain the declaration of cffi_int32hammingdistance so that CFFI can parse its signature.
Bit Hamming Distance CUDA Kernel
The same algorithm as outlined in Listing 1 can easily be put into a CUDA kernel:
Listing 3: CUDA kernel for computing the Hamming distance between two arrays of integers. This is essentially Listing 1, applied to the array element determined by the kernel's block and thread indices.
extern "C" __global__ void cupy_int32hammingdistance( const int n, const int* a, const int* b, int* dist ) { int elem_idx = blockIdx.x * blockDim.x + threadIdx.x; if (elem_idx >= n) { return; } int x = a[elem_idx] ^ b[elem_idx]; while(x != 0) { x = x & (x-1); dist[elem_idx]++; } }
CuPy only requires the kernel itself; it does not need to be stored in a separate source file. Instead, it can be provided to CuPy as a string in Python. Note that, in contrast to Listing 2, the kernel does not reset dist[elem_idx] to zero; it relies on the output tensor being zero-initialized, which will be taken care of when allocating the output in Listing 6.
Putting Everything Together
For assembling all parts, I will use a simple modular structure separating the actual implementation — using CFFI in cffi/ or CuPy in cupy.py — from the high-level method in torch.py:
common/
- __init__.py
- cffi/
  |- cffi.h
  |- cffi.c
  |- __init__.py
- cupy.py
- torch.py
CPU Implementation: First, I will put together the CPU implementation, meaning cffi.c and cffi.h. For simplicity, these can be put into their own directory, and the Python interface will be defined in the corresponding __init__.py:
Listing 4: Using CFFI, the C code can be compiled on-the-fly and accessed within Python directly. See the comments for details.
import os
import cffi
import inspect

ffi = cffi.FFI()
debug = False
use_openmp = True

# 1.
# This will be the directory of this __init__.py file.
# In other settings, it might make sense to specify this in a configuration file
# or hard-code it. The directory is only required to find cffi.h and cffi.c.
directory = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))

# 2.
# Read the header file:
with open('%s/cffi.h' % directory) as my_header:
    ffi.cdef(my_header.read())

# 3.
# Read and compile the c file:
# Compilation can be done using the usual GCC flags.
# Here, a debug mode and an optimized mode with and without OpenMP are provided.
with open('%s/cffi.c' % directory) as my_source:
    if debug:
        ffi.set_source(
            '_cffi',
            my_source.read(),
            extra_compile_args=['-pedantic', '-Wall', '-g', '-O0'],
        )
    # -ffast-math assumes there are no nans or infs!
    # -O3 includes -ffast-math!
    # https://stackoverflow.com/questions/22931147/stdisinf-does-not-work-with-ffast-math-how-to-check-for-infinity
    else:
        if use_openmp:
            ffi.set_source(
                '_cffi',
                my_source.read(),
                extra_compile_args=['-fopenmp', '-D use_openmp', '-O3', '-march=native'],
                extra_link_args=['-fopenmp'],
            )
        else:
            ffi.set_source(
                '_cffi',
                my_source.read(),
                extra_compile_args=['-O3', '-march=native'],
            )

# 4.
# Compile using the parameters above.
ffi.compile()
#ffi.compile(verbose=True)

# 5.
# This simply imports all compiled functions and makes them available in this module.
from _cffi import *
The code takes care of compiling the function in Listing 2 and interfacing it with Python. Afterwards, the function can be accessed through cffi.lib.cffi_int32hammingdistance, where Listing 4 corresponds to cffi/__init__.py. In the following, I will go through the essential steps in detail:
- The absolute path to the directory of the __init__.py file is obtained. This is necessary in order to locate the header and implementation files to be compiled. Depending on the setup, this can also be solved differently, for example, by hard-coding absolute paths.
- The header file is read such that CFFI is aware of the function declaration.
- The source file is read and compilation options are set. Here, the code allows several different settings including a debug setting without optimization and a setting supporting OpenMP.
- Knowing the function declaration (through the header file), the compilation options, and the source code, CFFI can compile everything.
- Finally, all compiled functions are imported such that they can be accessed more easily later.
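At this point, the compiled function can already be tested independently of PyTorch, for example on NumPy arrays. The following is a minimal sketch (not part of the original code), assuming the common/cffi package layout from above; the arrays' data pointers are cast to the C types expected by the function, analogous to what Listing 6 does for tensors:

import numpy as np
from common import cffi  # the package whose __init__.py is Listing 4

a = np.array([15, 255], dtype=np.int32)
b = np.array([19, 0], dtype=np.int32)
dist = np.zeros(a.shape, dtype=np.int32)

# Cast the length and the raw data pointers to the C types of cffi_int32hammingdistance.
cffi.lib.cffi_int32hammingdistance(
    cffi.ffi.cast('int', a.size),
    cffi.ffi.cast('int*', a.ctypes.data),
    cffi.ffi.cast('int*', b.ctypes.data),
    cffi.ffi.cast('int*', dist.ctypes.data),
)

print(dist)  # [3 8]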
GPU Implementation: For the CuPy part, I will create a separate module in cupy.py:
Listing 5: Similar to CFFI, CuPy also allows compiling a CUDA kernel on-the-fly. See the comments for details.
import torch

try:
    import cupy

    # 1. This will be used to call a kernel with source code provided as Python string.
    @cupy.util.memoize(for_each_device=True)
    def cunnex(strFunction):
        return cupy.cuda.compile_with_cache(globals()[strFunction]).get_function(strFunction)
except ImportError:
    print("CUPY cannot initialize, not using CUDA kernels")

class Stream:
    ptr = torch.cuda.current_stream().cuda_stream

# 2. The kernel as Python string from Listing 3
cupy_int32hammingdistance = '''
extern "C" __global__ void cupy_int32hammingdistance(
    const int n,
    const int* a,
    const int* b,
    int* dist
) {
    int elem_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (elem_idx >= n) {
        return;
    }

    int x = a[elem_idx] ^ b[elem_idx];
    while(x != 0) {
        x = x & (x-1);
        dist[elem_idx]++;
    }
}
'''
The CuPy interface is even simpler:
- This utility function will take care of compilation and interfacing the function. As argument, the function expects the name of a variable that holds the actual CUDA kernel source code.
- The source code is saved in this variable, not in a separate source file.
Finally, in torch.py, both implementations will be combined. The result is a high-level function int32_hamming_distance expecting two torch.int32 tensors as input. The function will automatically allocate memory for the output and call the appropriate interface depending on whether the inputs are on GPU memory or not. For determining whether a tensor is on GPU memory, a simple is_cuda function is provided (not shown here; a minimal sketch is given after the listing):
Listing 6: Putting the CFFI and CuPy implementations together in a high-level method int32_hamming_distance that automatically uses the CPU or GPU implementation depending on the input tensors.
def int32_hamming_distance(a, b):
    """
    Bit-wise Hamming distance.

    :param a: first tensor
    :type a: torch.Tensor
    :param b: second tensor
    :type b: torch.Tensor
    :return: hamming distance
    :rtype: torch.Tensor
    """

    if not a.is_contiguous():
        a = a.contiguous()
    assert (a.dtype == torch.int32)
    cuda = is_cuda(a)

    if not b.is_contiguous():
        b = b.contiguous()
    assert (b.dtype == torch.int32)
    assert is_cuda(b) is cuda

    assert len(a.shape) == len(b.shape)
    for d in range(len(a.shape)):
        assert a.shape[d] == b.shape[d]

    # 1. Initialize output tensor to hold the element-wise hamming distances.
    dist = a.new_zeros(a.shape).int()
    n = dist.nelement()

    if cuda:
        # 2. Call the CuPy implementation using the helper function cupy.cunnex.
        # The function returned by cupy.cunnex expects, among others, the grid/block division to use
        # and the kernel arguments; here, a.data_ptr() returns the pointer to the tensor a,
        # which is assumed to be on GPU memory.
        cupy.cunnex('cupy_int32hammingdistance')(
            grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
            block=tuple([512, 1, 1]),
            args=[n, a.data_ptr(), b.data_ptr(), dist.data_ptr()],
            stream=cupy.Stream
        )
    else:
        # 3. For CFFI, the inputs have to be cast to the target C equivalents using cffi.ffi.cast.
        # Afterwards, the C function can be called like a regular Python function using the converted arguments.
        _n = cffi.ffi.cast('int', n)
        _a = cffi.ffi.cast('int*', a.data_ptr())
        _b = cffi.ffi.cast('int*', b.data_ptr())
        _dist = cffi.ffi.cast('int*', dist.data_ptr())
        cffi.lib.cffi_int32hammingdistance(_n, _a, _b, _dist)

    return dist
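The is_cuda helper used above is not shown in the original code; a minimal sketch, assuming it simply wraps PyTorch's is_cuda attribute, could look as follows:

def is_cuda(mixed):
    """
    Check whether the given tensor lives on GPU memory.
    Minimal sketch; the original helper may accept more input types.
    """
    assert isinstance(mixed, torch.Tensor)
    return mixed.is_cuda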
Essentially, the function creates the output tensor to hold the element-wise Hamming distances and then calls the CuPy or CFFI interface depending on whether the inputs are stored in GPU memory or not:
- Create the output tensor, also on GPU if required. It will also be a torch.int32 tensor of the same size.
- The CuPy implementation is called through the helper function cupy.cunnex, which fetches the corresponding source code, compiles it (if not cached) and returns the corresponding function. The returned function expects the inputs — here, a.data_ptr() is used to access the pointer underlying the given tensor — and the grid/block subdivision for the kernel.
- Calling the CFFI implementation, instead, requires explicitly casting the inputs to the equivalent C types. Then, the function can be called like a regular Python function.
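To wrap up, here is a short usage sketch (not part of the original article), assuming the common package layout from above:

import torch
from common.torch import int32_hamming_distance

a = torch.tensor([15, 255], dtype=torch.int32)
b = torch.tensor([19, 0], dtype=torch.int32)

# CPU path, dispatched to the CFFI implementation:
print(int32_hamming_distance(a, b))
# tensor([3, 8], dtype=torch.int32)

# GPU path, dispatched to the CuPy/CUDA kernel, if available:
if torch.cuda.is_available():
    print(int32_hamming_distance(a.cuda(), b.cuda()))
    # tensor([3, 8], device='cuda:0', dtype=torch.int32)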
Conclusion
Overall, this article shows that implementing custom PyTorch operations in C and CUDA, supporting CPU and GPU computation, can be very simple using CFFI and CuPy. Furthermore, I provided a simple framework that can easily be extended to custom operations.