Implementing Tensorflow Operations in C++ — Including Gradients

In this article, I discuss a simple Tensorflow operation implemented in C++. While the example mostly builds upon the official documentation, it includes trainable parameters and the gradient computation is implemented in C++, as well. As such, the example is slightly more complex compared to the simple ZeroOut operation discussed in the documentation.

Although Tensorflow provides a thorough tutorial on how to add new operations, the provided example is rather simple and gradients are meant to be implemented in Python. However, in many practical cases, operations get more complex and involve parameters that are optimized. In order to get started implementing complex operations for Tensorflow in C++, I implemented a simple linear operation for neural networks (i.e. a matrix-vector multiplication operation, sometimes also referred to as inner product layer). The example includes both trainable parameters and gradients implemented in C++ instead of Python.

The example is not very general and should not be used in actual production code. Instead, it is meant to complement the simple example provided in the documentation. The code is available on GitHub:

Example on GitHub

Forward Operation

The listing below shows the implementation of the forward operation, i.e. given an input vector and a weight matrix, the matrix-vector product is calculated. The implementation is saved to inner_product.cc in an arbitrary directory:

/// \file inner_product.cc
/// \author David Stutz
/// \brief Implementation of an inner product (i.e. fully connected layer)
/// operation in Tensorflow.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/default/logging.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("InnerProduct")
  .Input("input: float")
  .Input("weights: float")
  .Output("inner_product: float")
  .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
    shape_inference::ShapeHandle input_shape;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input_shape));

    shape_inference::ShapeHandle weight_shape;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &weight_shape));
    
    shape_inference::DimensionHandle output_rows = c->Dim(weight_shape, 0);
  
    shape_inference::DimensionHandle input_rows = c->Dim(input_shape, 0);
    shape_inference::DimensionHandle weight_cols = c->Dim(weight_shape, 1);
    shape_inference::DimensionHandle merged;
    TF_RETURN_IF_ERROR(c->Merge(input_rows, weight_cols, &merged));

    c->set_output(0, c->Matrix(output_rows, 1));
    return Status::OK();
  });

/// \brief Implementation of an inner product operation.
/// \param context
/// \author David Stutz
class InnerProductOp : public OpKernel {
public:
  /// \brief Constructor.
  /// \param context
  explicit InnerProductOp(OpKernelConstruction* context) : OpKernel(context) {
    
  }
  
  /// \brief Compute the inner product.
  /// \param context
  void Compute(OpKernelContext* context) override {
    
    // some checks to be sure ...
    DCHECK_EQ(2, context->num_inputs());
    
    // get the input tensor
    const Tensor& input = context->input(0);
    
    // get the weight tensor
    const Tensor& weights = context->input(1);
    
    // check shapes of input and weights
    const TensorShape& input_shape = input.shape();
    const TensorShape& weights_shape = weights.shape();
    
    // check input is a standing vector
    DCHECK_EQ(input_shape.dims(), 2);
    DCHECK_EQ(input_shape.dim_size(1), 1);
    
    // check weights is matrix of correct size
    DCHECK_EQ(weights_shape.dims(), 2);
    DCHECK_EQ(input_shape.dim_size(0), weights_shape.dim_size(1));
    
    // create output shape
    TensorShape output_shape;
    output_shape.AddDim(weights_shape.dim_size(0));
    output_shape.AddDim(1);
            
    // create output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    
    // get the corresponding Eigen tensors for data access
    auto input_tensor = input.matrix<float>();
    auto weights_tensor = weights.matrix<float>();
    auto output_tensor = output->matrix<float>();
    
    for (int i = 0; i < output->shape().dim_size(0); i++) {
      output_tensor(i, 0) = 0;
      for (int j = 0; j < weights.shape().dim_size(1); j++) {
        output_tensor(i, 0) += weights_tensor(i, j)*input_tensor(j, 0);
      }
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("InnerProduct").Device(DEVICE_CPU), InnerProductOp);

Loosely following the documentation, the implementation contains the following important parts:

  • Beginning in line 13, the interface of the operation is defined; this includes the inputs and outputs as well as a function for shape inference. As discussed in the official documentation, attributes can also be added here.
  • The Compute method beginning in line 48 contains the actual implementation of the inner product operation.
  • For simplicity, the operation is implemented directly, beginning in line 81. Tensorflow should provide capabilities for an easier implementation, but I did not find them. The tensor contents are accessed directly via the underlying Eigen tensors. Thanks to C++11's auto, the types do not need to be known in detail, and the tensors can be accessed via tensor.vec<T>(), tensor.matrix<T>() or, in general, tensor.tensor<T, NDIMS>().
  • In line 94, the kernel is registered, which allows setting specific constraints such as the device the operation runs on. For simplicity, the implementation runs on the CPU.
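
To make the arithmetic of the Compute method easier to check by hand, the loops can be mirrored in a few lines of plain Python. This is only a reference sketch and not part of the original code: the function name inner_product and the nested-list representation are assumptions for illustration, and no Tensorflow is involved.

```python
def inner_product(x, W):
    """Reference forward pass: returns W x as an (m, 1) nested list.

    x is an (n, 1) column vector and W an (m, n) matrix, both given as
    nested lists (an assumption made for this sketch).
    """
    n = len(x)
    # mirror the shape checks of the C++ kernel
    if any(len(row) != 1 for row in x):
        raise ValueError("input must be a column vector of shape (n, 1)")
    if any(len(row) != n for row in W):
        raise ValueError("weights must have shape (m, n)")

    output = []
    for i in range(len(W)):
        # same accumulation as the inner loop of the C++ Compute method
        acc = 0.0
        for j in range(n):
            acc += W[i][j] * x[j][0]
        output.append([acc])
    return output


print(inner_product([[1], [2]], [[1, 2], [3, 4]]))  # [[5.0], [11.0]]
```

The values 5 and 11 are the ones the hard-coded forward test further below expects from the C++ operation.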

Gradient Operation

In the documentation, the operation gradients are implemented in Python. To be able to implement gradients in C++, the gradient operation is defined as a completely separate operation saved in inner_product_grad.cc, shown in the listing below.

Update. As isunchy mentioned in the comments, the matmul implementation in TensorFlow that I was not able to find is FastGemmFunctor.

/// \file inner_product_grad.cc
/// \author David Stutz
/// \brief Implementation of the gradient of an inner product operation, see
/// inner_product.cc.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("InnerProductGrad")
  .Input("grad: float32")
  .Input("input: float32")
  .Input("weights: float32")
  .Output("grad_input: float32")
  .Output("grad_weights: float32");

/// \brief Implementation of an inner product gradient operation.
/// Note that this operation is used in Python to register the gradient as
/// this is not possible in C++ right now.
/// \param context
/// \author David Stutz
class InnerProductGradOp : public OpKernel {
public:
  /// \brief Constructor.
  /// \param context
  explicit InnerProductGradOp(OpKernelConstruction* context) : OpKernel(context) {
    
  }
  
  /// \brief Compute the inner product gradients.
  /// \param context
  void Compute(OpKernelContext* context) override {
    
    // output and grad is provided as input
    DCHECK_EQ(3, context->num_inputs());

    // get the gradient tensor
    const Tensor& grad = context->input(0);
    
    // get the original input tensor
    const Tensor& input = context->input(1);
    
    // get the weight tensor
    const Tensor& weights = context->input(2);
    
    // get the shapes of input and weights; they are also the output shapes
    TensorShape input_shape = input.shape();
    TensorShape weights_shape = weights.shape();
    
    DCHECK_EQ(input_shape.dim_size(0), weights_shape.dim_size(1));
    DCHECK_EQ(weights_shape.dim_size(0), grad.shape().dim_size(0));
    
    // create output tensors
    Tensor* grad_input = NULL;
    Tensor* grad_weights = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &grad_input));
    OP_REQUIRES_OK(context, context->allocate_output(1, weights_shape, &grad_weights));
    
    // get the Eigen tensors for data access
    auto grad_tensor = grad.matrix<float>();
    auto weights_tensor = weights.matrix<float>();
    auto input_tensor = input.matrix<float>();
    auto grad_input_tensor = grad_input->matrix<float>();
    auto grad_weights_tensor = grad_weights->matrix<float>();
    
    // TODO couldn't really find basic MatMul operations and how to use them,
    // so doing the stuff manually, should be fine as example.
    // Update: see note above, matmul is implemented in FastGemmFunctor
    for (int i = 0; i < weights_shape.dim_size(1); i++) {
      grad_input_tensor(i, 0) = 0;
      for (int j = 0; j < grad.shape().dim_size(0); j++) {
        grad_input_tensor(i, 0) += grad_tensor(j, 0)*weights_tensor(j, i);
      }
    }
    
    for (int i = 0; i < weights_shape.dim_size(0); i++) {
      for (int j = 0; j < weights_shape.dim_size(1); j++) {
        grad_weights_tensor(i, j) = grad_tensor(i, 0)*input_tensor(j, 0);
      }
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("InnerProductGrad").Device(DEVICE_CPU), InnerProductGradOp);

The listing above is mostly analogous to the forward operation except for a minor difference:

  • Beginning in line 11, the interface of the operation is defined, taking the original input, the weights and the gradients from the top node in the computation graph (e.g. the top layer in neural network terms) as input, and defining the gradients with respect to the input and the weights as outputs. The shape inference function is omitted.
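
The two loops in the gradient kernel follow from the chain rule: for y = W x and an incoming gradient g = dL/dy, the gradient with respect to the input is W^T g and the gradient with respect to the weights is the outer product g x^T. The following plain-Python sketch mirrors those loops so the values can be checked by hand; the function name and the nested-list representation are assumptions for illustration and not part of the original code.

```python
def inner_product_grad(grad, x, W):
    """Reference backward pass for y = W x.

    grad is dL/dy with shape (m, 1), x has shape (n, 1), W has shape (m, n),
    all given as nested lists (an assumption made for this sketch).
    Returns (grad_input, grad_weights) with the shapes of x and W.
    """
    m, n = len(W), len(x)

    # dL/dx = W^T grad, same index order as the first C++ loop
    grad_input = [[0.0] for _ in range(n)]
    for i in range(n):
        for j in range(m):
            grad_input[i][0] += grad[j][0] * W[j][i]

    # dL/dW = grad x^T (outer product), as in the second C++ loop
    grad_weights = [[grad[i][0] * x[j][0] for j in range(n)] for i in range(m)]
    return grad_input, grad_weights


g_x, g_W = inner_product_grad([[1], [1]], [[1], [2]], [[1, 2], [3, 4]])
print(g_x)  # [[4.0], [6.0]]
print(g_W)  # [[1, 2], [1, 2]]
```

An incoming gradient of all ones corresponds to differentiating the sum of the outputs, which is what tf.gradients does by default.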

Given the gradient operation, it needs to be registered and associated with the forward operation. This is done in Python, specifically in _inner_product_grad.py:

#!/usr/bin/env python3
"""
Gradients for inner product.
"""

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
inner_product_grad_module = tf.load_op_library('build/libinner_product_grad.so')

@ops.RegisterGradient("InnerProduct")
def _inner_product_grad_cc(op, grad):
    """
    The gradient for `inner_product` using the operation implemented in C++.
    
    :param op: `inner_product` `Operation` that we are differentiating, which we can use
        to find the inputs and outputs of the original op.
    :param grad: gradient with respect to the output of the `inner_product` op.
    :return: gradients with respect to the input of `inner_product`.
    """
    
    return inner_product_grad_module.inner_product_grad(grad, op.inputs[0], op.inputs[1])    

It becomes clear that up to now, the forward operation and the gradient operation were completely independent of each other. Also note how the InnerProductGrad operation is imported in Python; this requires knowing the location of the corresponding shared library (i.e. the .so file). Building using CMake is discussed in the following section.
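
Because 'build/libinner_product_grad.so' is a relative path, loading the library depends on the directory the script is started from. One possible way around this, sketched here under the assumption that the build directory sits next to the Python files, is to construct the path from the location of the module itself. The helper name op_library_path is hypothetical and not part of the original code.

```python
import os


def op_library_path(name, base_dir=None):
    """Return the expected path of lib<name>.so inside <base_dir>/build.

    If base_dir is omitted, the directory containing this file is used
    (assumption: the build directory lives next to the Python sources).
    """
    if base_dir is None:
        base_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(base_dir, "build", "lib" + name + ".so")


# Intended use in _inner_product_grad.py (requires Tensorflow):
# inner_product_grad_module = tf.load_op_library(op_library_path("inner_product_grad"))
```

The result is then independent of the current working directory, at the cost of hard-coding the layout of the project.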

Building

Update. As an alternative to CMake, a bash script for building can be found in the comments. Also note that the provided CMake file might not work for Tensorflow 1.11; a workaround can, again, be found in the comments.

As I am most comfortable with CMake, I was relieved to find out that Bazel is not mandatory when implementing new operations. The following listing shows a simple CMakeLists.txt doing the job:

cmake_minimum_required(VERSION 2.8)

# get tensorflow include dirs, see https://www.tensorflow.org/how_tos/adding_an_op/
execute_process(COMMAND python3 -c "import tensorflow; print(tensorflow.sysconfig.get_include())" OUTPUT_VARIABLE Tensorflow_INCLUDE_DIRS OUTPUT_STRIP_TRAILING_WHITESPACE)

# C++11 required for tensorflow
set(CMAKE_CXX_FLAGS "-std=c++11 ${CMAKE_CXX_FLAGS}")

include_directories(${Tensorflow_INCLUDE_DIRS})
add_library(inner_product SHARED inner_product.cc)

include_directories(${Tensorflow_INCLUDE_DIRS})
add_library(inner_product_grad SHARED inner_product_grad.cc)

There are a few things to note:

  • Note line 4, where tensorflow.sysconfig.get_include() is used to get the include directories of the Tensorflow installation; this is also detailed in the documentation.
  • In lines 10 and 13, the operations are compiled as shared libraries.

Both operations (which are put together in Python) can be compiled using:

$ mkdir build
$ cd build
$ cmake ..
$ make

Of course, this assumes that all mentioned files are placed in the same directory. The shared libraries will then be found in the build directory: build/libinner_product.so and build/libinner_product_grad.so.

Tests

In order to illustrate the usage of the operation, both the forward and backward pass, some unit tests can be found in the listing below:

#!/usr/bin/env python3
"""
Tests for the inner product Tensorflow operation.
"""

import unittest
import numpy as np
import tensorflow as tf
import _inner_product_grad
inner_product_module = tf.load_op_library('build/libinner_product.so')

class InnerProductOpTest(unittest.TestCase):
    def test_raisesExceptionWithIncompatibleDimensions(self):
        with tf.Session(''):
            with self.assertRaises(ValueError):
                inner_product_module.inner_product([1, 2], [[1, 2], [3, 4]]).eval()
            with self.assertRaises(ValueError):
                inner_product_module.inner_product([1, 2], [1, 2, 3, 4]).eval()
            with self.assertRaises(ValueError):
                inner_product_module.inner_product([1, 2, 3], [[1, 2], [3, 4]]).eval()
            
    def test_innerProductHardCoded(self):
        with tf.Session(''):
            result = inner_product_module.inner_product([[1], [2]], [[1, 2], [3, 4]]).eval()
            self.assertEqual(result.shape[0], 2)
            self.assertEqual(result[0], 5)
            self.assertEqual(result[1], 11)
    
    def test_innerProductGradientXHardCoded(self):
        with tf.Session('') as sess:
            x = tf.placeholder(tf.float32, shape = (2))
            W = tf.constant(np.asarray([[1, 2], [3, 4]]).astype(np.float32))
            
            Wx_tf = tf.matmul(W, tf.reshape(x, [-1, 1]))
            Wx_inner_product = inner_product_module.inner_product(tf.reshape(x, [-1, 1]), W)
            
            grad_x_tf = tf.gradients(Wx_tf, x)
            grad_x_inner_product = tf.gradients(Wx_inner_product, x)
            
            gradient_tf = sess.run(grad_x_tf, feed_dict = {x: np.asarray([1, 2]).astype(np.float32)})
            gradient_inner_product = sess.run(grad_x_inner_product, feed_dict = {x: np.asarray([1, 2]).astype(np.float32)})
            
            self.assertEqual(gradient_tf[0][0], gradient_inner_product[0][0])
            self.assertEqual(gradient_tf[0][1], gradient_inner_product[0][1])
    
    def test_innerProductGradientWHardCoded(self):
        with tf.Session('') as sess:
            x = tf.constant(np.asarray([1, 2]).astype(np.float32))
            W = tf.placeholder(tf.float32, shape = (2, 2))
            
            Wx_tf = tf.matmul(W, tf.reshape(x, [-1, 1]))
            Wx_inner_product = inner_product_module.inner_product(tf.reshape(x, [-1, 1]), W)
            
            grad_W_tf = tf.gradients(Wx_tf, W)
            grad_W_inner_product = tf.gradients(Wx_inner_product, W)
            
            gradient_tf = sess.run(grad_W_tf, feed_dict = {W: np.asarray([[1, 2], [3, 4]]).astype(np.float32)})
            gradient_inner_product = sess.run(grad_W_inner_product, feed_dict = {W: np.asarray([[1, 2], [3, 4]]).astype(np.float32)})
            
            self.assertEqual(gradient_tf[0][0][0], gradient_inner_product[0][0][0])
            self.assertEqual(gradient_tf[0][0][1], gradient_inner_product[0][0][1])
            self.assertEqual(gradient_tf[0][1][0], gradient_inner_product[0][1][0])
            self.assertEqual(gradient_tf[0][1][1], gradient_inner_product[0][1][1])
    
    def test_innerProductRandom(self):
        with tf.Session(''):
            n = 4
            m = 5
            
            for i in range(100):
                x_rand = np.random.randint(10, size = (n, 1))
                W_rand = np.random.randint(10, size = (m, n))
                result_rand = np.dot(W_rand, x_rand)
                
                result = inner_product_module.inner_product(x_rand, W_rand).eval()
                np.testing.assert_array_equal(result, result_rand)
    
    def test_innerProductGradientXRandom(self):
        with tf.Session('') as sess:
            n = 4
            m = 5
            
            x = tf.placeholder(tf.float32, shape = (n))
            W = tf.placeholder(tf.float32, shape = (m, n))
            
            Wx_tf = tf.matmul(W, tf.reshape(x, [-1, 1]))
            Wx_inner_product = inner_product_module.inner_product(tf.reshape(x, [-1, 1]), W)
            
            grad_x_tf = tf.gradients(Wx_tf, x)
            grad_x_inner_product = tf.gradients(Wx_inner_product, x)
            
            for i in range(100):
                x_rand = np.random.randint(10, size = (n))
                W_rand = np.random.randint(10, size = (m, n))
                
                gradient_tf = sess.run(grad_x_tf, feed_dict = {x: x_rand, W: W_rand})
                gradient_inner_product = sess.run(grad_x_inner_product, feed_dict = {x: x_rand, W: W_rand})
                
                np.testing.assert_array_equal(gradient_tf, gradient_inner_product)
                
    def test_innerProductGradientWRandom(self):
        with tf.Session('') as sess:
            n = 4
            m = 5
            
            x = tf.placeholder(tf.float32, shape = (n))
            W = tf.placeholder(tf.float32, shape = (m, n))
            
            Wx_tf = tf.matmul(W, tf.reshape(x, [-1, 1]))
            Wx_inner_product = inner_product_module.inner_product(tf.reshape(x, [-1, 1]), W)
            
            grad_W_tf = tf.gradients(Wx_tf, W)
            grad_W_inner_product = tf.gradients(Wx_inner_product, W)
            
            for i in range(100):
                x_rand = np.random.randint(10, size = (n))
                W_rand = np.random.randint(10, size = (m, n))
                
                gradient_tf = sess.run(grad_W_tf, feed_dict = {x: x_rand, W: W_rand})
                gradient_inner_product = sess.run(grad_W_inner_product, feed_dict = {x: x_rand, W: W_rand})
                
                np.testing.assert_array_equal(gradient_tf, gradient_inner_product)
                  
                
if __name__ == '__main__':
    unittest.main()

Some comments:

  • Note that in line 10, only the forward operation — libinner_product.so — is imported. Remember that the backward operation was registered in _inner_product_grad.py which is imported in line 9 and itself imports libinner_product_grad.so.
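  • The test beginning in line 13 illustrates some of the cases that are caught by the shape inference function defined for the forward pass. In my experience, checks inside the Compute function (e.g. using DCHECK_XX) are handled differently than checks in the shape inference function: shape inference failures surface in Python as a ValueError when the graph is built, whereas the DCHECK macros are only active in debug builds and abort the process instead of raising an exception.
  • The test starting in line 22 illustrates a simple forward pass.
  • The remaining tests illustrate gradient computation with respect to both the input and the weights, using hard-coded as well as random values.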
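
The gradient tests above compare against tf.matmul. As a complementary check that does not depend on a second Tensorflow operation, the analytic gradient can also be compared against central finite differences. The sketch below does this in plain Python for the gradient with respect to the input; all function names are assumptions for illustration, and the loss is taken to be the sum of the outputs, which is what tf.gradients differentiates by default.

```python
def forward(x, W):
    """y = W x for a flat list x of length n and an (m, n) nested list W."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]


def loss(x, W):
    """Scalar loss: the sum of all outputs."""
    return sum(forward(x, W))


def analytic_grad_x(x, W):
    """dL/dx = W^T * ones, i.e. the column sums of W."""
    return [sum(row[j] for row in W) for j in range(len(x))]


def numeric_grad_x(x, W, eps=1e-6):
    """Central finite differences of the loss with respect to each x_j."""
    grads = []
    for j in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[j] += eps
        x_minus[j] -= eps
        grads.append((loss(x_plus, W) - loss(x_minus, W)) / (2.0 * eps))
    return grads


def max_abs_diff(a, b):
    """Largest absolute element-wise difference between two flat lists."""
    return max(abs(u - v) for u, v in zip(a, b))


x = [1.0, 2.0]
W = [[1.0, 2.0], [3.0, 4.0]]
print(analytic_grad_x(x, W))  # [4.0, 6.0]
print(max_abs_diff(analytic_grad_x(x, W), numeric_grad_x(x, W)) < 1e-4)  # True
```

Note that the finite-difference result is only approximate, so it is compared with a tolerance rather than with exact equality as in the unit tests above.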

Conclusion

The presented example is simple enough to demonstrate the general idea of adding new operations in Tensorflow. Still, it also includes some more complex cases — such as trainable parameters and the gradient operation implemented in C++ — compared to the official documentation. Overall, Tensorflow tries to make custom operations as easy as possible. Nevertheless, the internal mechanics of Tensorflow are hard to understand — which will hopefully get easier with improved documentation and comments within the Tensorflow core.

What is your opinion on this article? Let me know your thoughts on Twitter @davidstutz92 or LinkedIn in/davidstutz92.