A Note on Extending Tensorflow, PyTorch, Theano and Torch for Deep Learning on Custom Data Structures

Many recent deep learning frameworks such as Tensorflow, PyTorch, Theano or Torch are based on dense tensors. However, deep learning on non-tensor data structures is also interesting – especially for sparse, three-dimensional data. This article summarizes some of my experiences regarding deep learning on custom data structures in the mentioned libraries.

I spent the last few days investigating how to extend popular deep learning libraries such as Tensorflow, PyTorch, Theano or Torch with custom (non-tensor) data structures. This article is meant as a brief note on what I found out. Also let me know your thoughts in the comments.

All the mentioned libraries are essentially built upon graphs of tensor operations. If these tensor operations support differentiation (i.e. backward passes are implemented), they lend themselves to large-scale deep learning, usually utilizing GPU acceleration. As such, these libraries are heavily built upon multi-dimensional tensors, usually dense ones (although some libraries offer support for sparse tensors).

However, deep learning is not only interesting on dense tensors. Especially for sparse data and in higher dimensions, other data structures as well as different operations become interesting. Examples include OctNet [1], sparse 3D convolutional networks [2] and Vote3Deep [3], to name a few of my recent readings. These all introduce different data structures or some sort of specialized convolution operation.

In the following, I want to describe my impression of how (or whether) to extend the mentioned libraries to custom data structures. The motivation is to use one of the libraries as an experimental framework that conveniently interfaces C++/CUDA code (in which the data structures are probably implemented) with a simple scripting language such as Python or Lua.

Tensorflow

Tensorflow provides a detailed guide on how to add custom operations: Adding a New Op. Examples can also be found in some of my recent snippets.

However, these examples do not tackle the question of how to define custom operations on non-tensor data structures. After spending some days studying Tensorflow's source code (in particular, the core framework), it became clear that Tensorflow is built upon and around Eigen's tensor module, which is briefly documented in its README. Overall, I got the impression that it would be very (very) labor-intensive to integrate custom data structures into the core framework directly.

While this sounds pessimistic, there might still be hope. One option might be to express the data structure of interest as a set of Tensorflow tensors. Implemented in C/C++, each data structure is inherently built upon arrays, and these can easily be wrapped as tensors.

#include "tensorflow/core/framework/tensor.h"

/// Tensor to array conversion.
/// @author David Stutz
template<typename T, int NDIMS>
class TensorConversion {
public:

  /// Access the underlying data pointer of the tensor.
  /// @param tensor the tensor whose data should be accessed
  /// @return pointer to the tensor's underlying (row-major) data array
  static T* AccessDataPointer(const tensorflow::Tensor &tensor) {
    // (1) Get the underlying Eigen tensor; there are also
    // the .matrix and .vector methods for matrices and vectors.
    auto tensor_map = tensor.tensor<T, NDIMS>();
    // (2) Get the underlying array from the tensor.
    auto array = tensor_map.data();
    // (3) Cast away the const qualifier and return the array;
    // writing through this pointer modifies the tensor's data.
    return const_cast<T*>(array);
  }
};

The above example illustrates how to access the underlying data array given a tensorflow::Tensor:

  1. The Eigen tensor is accessed using the .tensor method; for vectors and matrices, Tensorflow also provides the .vector and .matrix methods.
  2. Given the Eigen tensor, the array can be obtained using the .data method.
  3. As the Eigen tensor of a const tensorflow::Tensor only provides a const pointer, the const qualifier is cast away before the pointer is returned.
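Once the raw pointer is available, individual elements have to be addressed through a flat offset. The Eigen tensors used by Tensorflow are stored in row-major order, i.e. the last dimension varies fastest. The following minimal sketch (in Python for readability; the helper row_major_offset is hypothetical and not part of Tensorflow's API) computes this offset for a multi-dimensional index.

```python
def row_major_offset(index, shape):
    """Flat offset of a multi-dimensional index in a row-major (C-order) buffer.

    Hypothetical helper for illustration; the last dimension varies fastest,
    as in the row-major Eigen tensors used by Tensorflow.
    """
    if len(index) != len(shape):
        raise ValueError("index and shape must have the same number of dimensions")
    offset = 0
    for i, n in zip(index, shape):
        if not 0 <= i < n:
            raise IndexError("index %d out of range for a dimension of size %d" % (i, n))
        # Horner-style accumulation: ((i0 * n1 + i1) * n2 + i2) * ...
        offset = offset * n + i
    return offset


# Usage: a 2 x 3 x 4 tensor seen as one flat buffer of 24 elements; each
# element stores its own index, laid out in row-major order by construction.
shape = (2, 3, 4)
flat = [(a, b, c) for a in range(2) for b in range(3) for c in range(4)]
print(flat[row_major_offset((1, 2, 3), shape)])
```

In C++, the same offset would be used to index the pointer returned by AccessDataPointer.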

Overall, it might be possible to wrap the C++ arrays of custom data structures in Tensorflow tensors. In Python, these can be passed around as regular Tensorflow tensors, while custom Tensorflow operations implement the data-structure-specific operations. However, this idea is only a blueprint and I did not pursue this approach further. Instead, I went on to investigate PyTorch.
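To make this blueprint slightly more concrete, the following sketch shows how a hypothetical sparse voxel grid (similar in spirit to the sparse 3D structures mentioned above) could be reduced to a set of flat arrays, each of which could then be wrapped as a tensor. A structure-specific operation only receives these arrays, as a custom operation would receive tensors. Everything here (the class SparseGrid, the array layout and the operation scale_occupied) is an assumption for illustration, written in plain Python rather than against Tensorflow's API.

```python
from array import array


class SparseGrid:
    """Hypothetical sparse 3D grid: only occupied voxels are stored.

    The structure consists of flat arrays only, each of which could be
    wrapped as a separate 1D tensor and passed to custom operations.
    """

    def __init__(self, resolution):
        self.resolution = resolution  # the grid has resolution^3 voxels
        self.coords = array('i')      # 3 ints (x, y, z) per occupied voxel
        self.values = array('d')      # one feature value per occupied voxel

    def insert(self, x, y, z, value):
        self.coords.extend((x, y, z))
        self.values.append(value)

    def to_arrays(self):
        # The "set of tensors" representation of the data structure.
        return self.coords, self.values


def scale_occupied(coords, values, factor):
    """Structure-specific operation: scales the features of occupied voxels.

    Works on the flat arrays only and returns new arrays; empty space is
    neither stored nor touched.
    """
    return coords[:], array('d', (v * factor for v in values))


grid = SparseGrid(resolution=64)
grid.insert(1, 2, 3, 0.5)
grid.insert(10, 20, 30, 2.0)
coords, values = scale_occupied(*grid.to_arrays(), factor=2.0)
print(list(coords), list(values))
```

The design choice is that the data structure owns nothing but arrays, so the framework only ever sees tensors while the semantics live in the custom operations.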

PyTorch

PyTorch might be among the youngest deep learning libraries. In contrast to Tensorflow, where this was problematic at the time of writing, PyTorch allows one to create computation graphs dynamically; see this Quora question for a short discussion.

PyTorch also offers documentation on extending PyTorch; however, at the time of writing, this documentation is still at an early stage. The example from the documentation discusses a simple linear layer operating on tensors, without any complex details.

linear.py
import torch

# (1) LinearOp will define a custom operation corresponding to
# a fully connected layer; the gradients will also be defined.
# forward and backward are static methods that receive a context
# object ctx for passing data from the forward to the backward pass.
class LinearOp(torch.autograd.Function):
    """
    Linear Operation for PyTorch.
    """

    # (1.1) The forward pass is easily implemented; the most
    # interesting part is saving the input and the parameters
    # for the backward pass. They can then be accessed
    # via ctx.saved_tensors.
    @staticmethod
    def forward(ctx, input, weight, bias = None):
        # can be accessed using ctx.saved_tensors
        ctx.save_for_backward(input, weight, bias)

        # compute input*weight + bias; input is N x input_features,
        # weight is input_features x output_features
        output = input.mm(weight)
        if bias is not None:
            # add the bias to each of the N rows
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # (1.2) The backward pass accesses the saved tensors in
    # ctx.saved_tensors and computes the gradients in a
    # straightforward way.
    @staticmethod
    def backward(ctx, grad_output):
        # get saved tensors and initialize gradients
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # the checks below are for efficiency and can also be skipped
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight.t())
        if ctx.needs_input_grad[1]:
            grad_weight = input.t().mm(grad_output)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

# (2) LinearNN wraps the linear operation into a neural network module.
# Note that the module defines a set of nn.Parameters in the constructor
# corresponding to the trainable weights and the forward pass is
# implemented using the linear operation from above. Differentiation,
# i.e. the backward pass is defined implicitly through the above operation.
class LinearNN(torch.nn.Module):
    """
    Linear neural network module based on the operation defined above.
    """

    def __init__(self, input_features, output_features, bias = True):
        super(LinearNN, self).__init__()

        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameter will automatically register the variable as the module's
        # parameter; if optional parameters are not used, they should be registered as None
        self.weight = torch.nn.Parameter(torch.empty(input_features, output_features))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)

        # initialize weights
        self.weight.data.uniform_(-0.1, 0.1)
        if self.bias is not None:
            self.bias.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        # custom autograd Functions are invoked via .apply
        return LinearOp.apply(input, self.weight, self.bias)
main.py
import numpy
import torch
import linear
import random

# (3) A simple, one-layer network consisting of the linear layer
# defined in linear.py and a (log-)softmax output layer.
class Net(torch.nn.Module):
    """
    Network definition, consisting of a linear layer only.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = linear.LinearNN(6, 2, True)

    def forward(self, x):
        x = self.fc1(x)
        # nll_loss expects log-probabilities, hence log_softmax instead of softmax
        return torch.nn.functional.log_softmax(x, dim = 1)

if __name__ == '__main__':

    # (4) A toy dataset for classification is created.
    N = 1000
    inputs = numpy.zeros((N, 6))
    outputs = numpy.zeros((N, 1))
    for n in range(N):
        outputs[n, 0] = random.randint(0, 1)
        if outputs[n, 0] > 0:
            inputs[n, 0:3] = 1
        else:
            inputs[n, 3:6] = 1

    model = Net()
    optimizer = torch.optim.SGD(model.parameters(), lr = 0.05)

    # (5) The training loop randomly selects some samples, computes a forward pass
    # and uses an optimizer and a loss for training.
    # Coming from Torch (Lua), this setup might look familiar.
    for t in range(100):
        indices = numpy.arange(N)
        numpy.random.shuffle(indices)
        indices = indices[0:10]

        data = torch.from_numpy(inputs[indices]).float()
        # the target should be of type long: https://discuss.pytorch.org/t/problems-with-target-arrays-of-int-int32-types-in-loss-functions/140
        # and it has to be 1D: https://github.com/torch/cutorch/issues/227
        target = torch.from_numpy(outputs[indices, 0]).long()

        # reset the gradients, as backward() accumulates them
        optimizer.zero_grad()
        pred = model(data)
        loss = torch.nn.functional.nll_loss(pred, target)
        loss.backward()
        optimizer.step()

        if t % 10 == 0:
            # no gradients are needed for evaluation
            with torch.no_grad():
                pred = model(torch.from_numpy(inputs).float())
            accuracy = (pred.numpy().argmax(1) == outputs[:, 0]).astype(int).sum() / float(N)
            print(accuracy)

The code above shows a simple example adapted from the documentation:

  1. LinearOp defines a differentiable operation implementing a linear, i.e. fully connected, layer:
    1. The forward pass is straightforward; the only interesting part is that tensors, for example the input, can explicitly be saved for the backward pass using save_for_backward. The saved tensors are then accessible through saved_tensors.
    2. The backward pass accesses the saved tensors and computes the gradients with respect to input and parameters.
  2. LinearNN wraps the defined linear operation into a neural network module. The important part is the constructor, which defines the trainable parameters. The forward pass simply calls the defined linear operation; this implicitly also defines the backward pass.
  3. In main.py, a simple usage example is shown; the network Net consists of one LinearNN and a softmax layer.
  4. The toy dataset is a simple classification problem.
  5. The training loop looks familiar when coming from Torch (Lua); in each iteration, a random mini-batch of inputs and targets is passed through the model, the loss is computed and a single optimization step is performed.
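The gradient formulas in the backward pass are easy to get wrong. With the weight stored as an input_features x output_features matrix, the input gradient is grad_output times the transposed weight, the weight gradient is the transposed input times grad_output, and the bias gradient sums grad_output over the batch. A framework-free way to double-check such formulas is a finite-difference gradient check. The following minimal sketch uses plain Python lists and an arbitrarily chosen scalar loss L = sum(output * g), so that the gradient arriving at the layer is exactly g; all helper names are hypothetical and not part of PyTorch.

```python
import random


def matmul(left, right):
    """Plain-Python product of an (n x k) and a (k x m) matrix."""
    return [[sum(left[i][p] * right[p][j] for p in range(len(right)))
             for j in range(len(right[0]))] for i in range(len(left))]


def transpose(m):
    return [list(row) for row in zip(*m)]


def linear_forward(x, w, b):
    """output = x w + b with x: (N x in), w: (in x out), b: (out)."""
    return [[v + b[j] for j, v in enumerate(row)] for row in matmul(x, w)]


def linear_backward(x, w, grad_output):
    """Analytic gradients, using the same formulas as LinearOp.backward."""
    grad_x = matmul(grad_output, transpose(w))        # (N x out)(out x in)
    grad_w = matmul(transpose(x), grad_output)        # (in x N)(N x out)
    grad_b = [sum(col) for col in zip(*grad_output)]  # sum over the batch
    return grad_x, grad_w, grad_b


def numeric_grad(f, m, eps=1e-6):
    """Central finite differences of f() w.r.t. every entry of matrix m."""
    grad = [[0.0] * len(row) for row in m]
    for i, row in enumerate(m):
        for j in range(len(row)):
            old = row[j]
            row[j] = old + eps
            plus = f()
            row[j] = old - eps
            minus = f()
            row[j] = old  # restore the original value
            grad[i][j] = (plus - minus) / (2 * eps)
    return grad


def max_abs_diff(a, c):
    return max(abs(u - v) for ra, rc in zip(a, c) for u, v in zip(ra, rc))


random.seed(0)
N, n_in, n_out = 4, 3, 2
x = [[random.uniform(-1, 1) for _ in range(n_in)] for _ in range(N)]
w = [[random.uniform(-1, 1) for _ in range(n_out)] for _ in range(n_in)]
b = [random.uniform(-1, 1) for _ in range(n_out)]
g = [[random.uniform(-1, 1) for _ in range(n_out)] for _ in range(N)]


def loss():
    # L = sum(output * g), hence dL/d(output) = g
    out = linear_forward(x, w, b)
    return sum(o * gi for ro, rg in zip(out, g) for o, gi in zip(ro, rg))


grad_x, grad_w, grad_b = linear_backward(x, w, g)
err_x = max_abs_diff(grad_x, numeric_grad(loss, x))
err_w = max_abs_diff(grad_w, numeric_grad(loss, w))
err_b = max_abs_diff([grad_b], numeric_grad(loss, [b]))  # b as a 1-row matrix
print(err_x, err_w, err_b)
```

All three differences should be tiny (on the order of floating-point round-off); a transposed or swapped factor in linear_backward would show up immediately as a shape error or a large difference.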

Unfortunately, the documentation does not go beyond the above example. In particular, it does not discuss more complex modules or different input types and it is difficult to get any detailed information. As I wanted to avoid digging into the source code as I did with Tensorflow, I moved on to check Theano.

Theano

Theano is the only one of these libraries that explicitly discusses custom data types in its documentation: Making the Double Type. A simple example building upon this guide is shown below. However, as also discussed on StackOverflow, it is not trivial to get differentiation working for custom types; therefore, the example below does not fully work. Any help in the comments is appreciated!

double.py
import theano

# (1) Defines the double type, extending theano.gof.Type.
class Double(theano.gof.Type):
    """
    Double type for theano.
    """

    dtype = 'float64'

    # (1.1) Each type needs a filter, which is used to cast a given value to
    # the defined type. The method below follows the documentation.
    def filter(self, value, strict = False, allow_downcast = None):
        """
        Cast to double.
        """

        if strict:
            # in strict mode, only values that already are floats are accepted
            if isinstance(value, float):
                return value
            else:
                raise TypeError('Expected a float!')
        elif allow_downcast:
            return float(value)
        else:
            value_float = float(value)
            if value_float == value:
                return value_float
            else:
                raise TypeError('The double type cannot accurately represent %s of type %s' % (value, type(value)))
    
    # (1.2) The remaining methods are for checking for equality
    # and are left-overs from the discussion in
    # https://stackoverflow.com/questions/41858327/how-to-define-custom-theano-types-allowing-differentiation
    def values_eq_approx(self, value_a, value_b, tolerance = 1e-6):
        """
        Check whether values are approximately equal.
        """

        # relative difference; guard against division by zero if both values are 0
        denominator = abs(value_a) + abs(value_b)
        if denominator == 0:
            return True
        return abs(value_a - value_b) / denominator < tolerance

    def ones_like(self, model, dtype):
        return double('1')

    # Defining __eq__ alone makes instances unhashable in Python 3
    # (__hash__ is then implicitly set to None), so __hash__ is defined as well.
    def __eq__(self, other):
        """
        To allow using ==.
        """

        return type(self) is Double and type(other) is Double

    def __hash__(self):
        return hash(Double)

# is already used in the operations below
double = Double()

# (2) The idea is to define an add operation on the
# introduced double type, however, the gradient computation
# is not as simple as expected ...
class DoubleAddOp(theano.Op):
    """
    Add doubles.
    """

    __props__ = ()

    # (2.1) The make_node method defines input and output;
    # it follows the examples from the documentation.
    def make_node(self, x, y):
        """
        Define output variables.
        """

        # check input types
        if isinstance(x, (int, float)):
            x = theano.gof.Constant(double, x)
        if isinstance(y, (int, float)):
            y = theano.gof.Constant(double, y)

        if x.type != double or y.type != double:
            raise TypeError('DoubleAddOp only works on doubles.')

        return theano.gof.Apply(self, [x, y], [double()])

    # (2.2) The perform method computes the output, i.e.
    # performs addition.
    def perform(self, node, inputs, output_storage):
        """
        Perform the addition.
        """

        x = inputs[0]
        y = inputs[1]
        z = output_storage[0]
        z[0] = x + y

    def infer_shape(self, node, input_shapes):
        """
        Used to infer the shape before performing the operation.
        """

        return [input_shapes[0]] # see theano/tensor/opt.py, line 1318

    # (2.3) The grad method is supposed to compute the gradients, 
    # but as described on StackOverflow this is not as simple ...
    def grad(self, inputs, output_grads):
        """
        Compute the gradients w.r.t. both inputs.
        """

        return [output_grads[0]*1, output_grads[0]*1]

    def __str__(self):
        """
        Identification.
        """

        return 'DoubleAddOp'

dadd = DoubleAddOp()
main.py
"""
Tests for the double type and its operations.

Also see for discussion: http://stackoverflow.com/questions/41858327/how-to-define-custom-theano-types-allowing-differentiation
"""

import theano
import random
import unittest
from double import double, dadd

class TestDoubleOps(unittest.TestCase):
    """
    Test operations on doubles.
    """

    # update to the latest Theano if the tests cause unclosed-file errors
    def test_DoubleAddOpPerform(self):
        x = double('x')
        y = double('y')
        z = dadd(x, y)
        f = theano.function([x, y], z)

        for i in range(100):
            x_value = random.random()
            y_value = random.random()
            self.assertAlmostEqual(f(x_value, y_value), x_value + y_value)

    def test_DoubleAddOpGrad(self):
        x = double('x')
        y = double('y')
        z = dadd(x, y)
        gx = theano.grad(z, x)
        gy = theano.grad(z, y)
        f = theano.function([x, y], [gx, gy])

        for i in range(100):
            x_value = random.random()
            y_value = random.random()

            gx_value, gy_value = f(x_value, y_value)
            self.assertAlmostEqual(gx_value, 1)
            self.assertAlmostEqual(gy_value, 1)

if __name__ == '__main__':
    unittest.main()

The above example follows the documentation to implement the double type and define an operation on this type — this would be exactly what this article is about. Let's see:

  1. The double type extends theano.gof.Type:
    1. The type definition includes a method for casting types to double; this is done in the filter method. The structure closely follows the documentation.
    2. All remaining methods are used to allow equality checking and are left-overs from the discussion on StackOverflow.
  2. DoubleAddOp is intended to define an add operation on the double type:
    1. The make_node method defines inputs and outputs; details are explained in the documentation.
    2. perform represents the actual operation, i.e. the addition of two doubles.
    3. Finally, grad is supposed to compute the gradients. But as discussed on StackOverflow this is non-trivial.
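Stripped of Theano, the two pieces above are simple. The following sketch mirrors them in plain Python: filter_double follows the three casting cases of the filter method, and double_add_grad states what grad should return for z = x + y, namely the incoming output gradient passed through unchanged to both inputs (since dz/dx = dz/dy = 1). All function names are hypothetical; the sketch does not use Theano's API and does not resolve the Theano-specific differentiation problem.

```python
def filter_double(value, strict=False, allow_downcast=None):
    """Plain-Python mirror of Double.filter: cast value to a float."""
    if strict:
        # strict mode: only values that already are floats are accepted
        if isinstance(value, float):
            return value
        raise TypeError('Expected a float!')
    if allow_downcast:
        # downcasting allowed: a lossy conversion is acceptable
        return float(value)
    # default: convert only if no information is lost
    value_float = float(value)
    if value_float == value:
        return value_float
    raise TypeError('The double type cannot accurately represent %s of type %s'
                    % (value, type(value)))


def double_add_perform(x, y):
    """Forward pass of the add operation: z = x + y."""
    return x + y


def double_add_grad(output_grad):
    """Backward pass of the add operation.

    Since dz/dx = dz/dy = 1, the gradient w.r.t. both inputs is the
    incoming output gradient, passed through unchanged.
    """
    return [output_grad, output_grad]


# Finite-difference check for a downstream loss L = 3 * z, i.e. dL/dz = 3.
eps = 1e-6
x, y = 0.25, 0.5
gx, gy = double_add_grad(3.0)
num_gx = (3 * double_add_perform(x + eps, y) - 3 * double_add_perform(x - eps, y)) / (2 * eps)
num_gy = (3 * double_add_perform(x, y + eps) - 3 * double_add_perform(x, y - eps)) / (2 * eps)
print(filter_double(1), gx, gy, num_gx, num_gy)
```

For example, the integer 2**53 + 1 has no exact float representation, so the default mode rejects it while allow_downcast=True silently rounds it.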

Also check main.py for usage examples regarding the forward and backward pass, although the backward pass still has its problems. Overall, I think that, in the long run, Theano might be a good option for working with custom data structures. However, it might require some time and help to get started. Still, I moved on to Torch.

Torch

Although I tried to avoid Torch, knowing that Lua is not as comfortable as Python, I also knew that Torch easily allows using custom data structures for deep learning. A simple example is given below:

-- Small example to test forward and backward passes of custom data structures.

require('math')
require('torch')
require('nn')

-- (1) CustomDataStructure will wrap two torch tensors as x1 and x2.
-- We will then define a nn module on this data structure.
-- The below definition of the data structure is straightforward.
--- @class CustomDataStructure
-- This class will be our simple test data structure.
CustomDataStructure = {}
CustomDataStructure.__index = CustomDataStructure

--- Creates a new CustomDataStructure with 0-vectors of the given size
-- @param n vector size
function CustomDataStructure.create(n)
  local cds = {}
  setmetatable(cds, CustomDataStructure)
  cds.x1 = torch.Tensor(n):fill(0)
  cds.x2 = torch.Tensor(n):fill(0)
  return cds
end

-- (2) CustomLinear will implement a linear, fully connected, layer which
-- computes a linear operation on both x1 and x2 simultaneously using
-- the same parameters. It expects as input a CustomDataStructure and
-- also returns one. Note that the module looks like any other Torch module.
-- The backward pass is left unimplemented, but would work analogously to the forward pass.
--- @class CustomLinear
CustomLinear, CustomLinearParent = torch.class('nn.CustomLinear', 'nn.Module')

--- Initialize the layer specifying the number of input and output units.
-- @param nInputUnits number of input units
-- @param nOutputUnits number of output units
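function CustomLinear:__init(nInputUnits, nOutputUnits)
  -- Torch classes use __init as constructor; call the parent constructor first
  CustomLinearParent.__init(self)
  self.nInputUnits = nInputUnits
  self.nOutputUnits = nOutputUnits
  self.weight = torch.Tensor(nOutputUnits, nInputUnits):fill(0)
end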

--- Compute output.
-- @param input input of type CustomDataStructure
-- @return output of type CustomDataStructure
function CustomLinear:updateOutput(input)
  self.output = CustomDataStructure.create(self.nOutputUnits)
  self.output.x1 = torch.mv(self.weight, input.x1)
  self.output.x2 = torch.mv(self.weight, input.x2)
  return self.output
end

--- Backward pass is not implemented in this example.
function CustomLinear:updateGradInput(input, gradOutput)
  error('CustomLinear:updateGradInput is not implemented')
end

-- (3) To test the module, a one-layer network is created
-- and a simple instantiation of CustomDataStructure is passed
-- through it.
N = 10
x1 = torch.Tensor(N):fill(0)
x2 = torch.Tensor(N):fill(1)
x = CustomDataStructure.create(N)
x.x1 = x1
x.x2 = x2

model = nn.Sequential()
module = nn.CustomLinear(10, 1)
module.weight = torch.Tensor(1, 10):fill(1)
model:add(module)

y = model:forward(x)
print(y.x1)
print(y.x2)

Finally, Torch does exactly what this article is all about:

  1. First, CustomDataStructure defines a custom type, a class wrapping two Torch tensors that can be accessed via .x1 and .x2.
  2. CustomLinear directly extends nn.Module and operates on the defined data structure. The updateOutput method takes an instance of CustomDataStructure as input and applies a linear operation to both x1 and x2. Note that the backward pass is omitted for simplicity, but it would be implemented analogously in updateGradInput (and accGradParameters for the weights).
  3. Finally, to test the setup, a simple CustomDataStructure is instantiated and passed through the custom linear module.
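To illustrate why the omitted backward pass would work analogously, here is a minimal framework-free sketch in Python (not Lua, and not Torch's API) that mirrors the example: CustomDataStructure holds two vectors, and CustomLinear shares one weight matrix between them. In the spirit of Torch's updateGradInput, update_grad_input returns a gradient with the same structure as the input; in the spirit of accGradParameters, acc_grad_parameters sums the weight-gradient contributions of both members, since the weights are shared. The snake_case method names only imitate Torch's conventions and are assumptions for illustration.

```python
class CustomDataStructure:
    """Two vectors x1 and x2, mirroring the Lua example."""

    def __init__(self, x1, x2):
        self.x1 = list(x1)
        self.x2 = list(x2)


def mv(w, v):
    """Matrix-vector product w v for a list-of-rows matrix."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def mtv(w, v):
    """Transposed matrix-vector product w^T v."""
    return [sum(w[i][j] * v[i] for i in range(len(w))) for j in range(len(w[0]))]


class CustomLinear:
    def __init__(self, n_input_units, n_output_units):
        self.weight = [[0.0] * n_input_units for _ in range(n_output_units)]
        self.grad_weight = [[0.0] * n_input_units for _ in range(n_output_units)]

    def update_output(self, inp):
        # the same weights are applied to both members
        return CustomDataStructure(mv(self.weight, inp.x1), mv(self.weight, inp.x2))

    def update_grad_input(self, inp, grad_output):
        # gradient w.r.t. each member: W^T times that member's output gradient
        return CustomDataStructure(mtv(self.weight, grad_output.x1),
                                   mtv(self.weight, grad_output.x2))

    def acc_grad_parameters(self, inp, grad_output, scale=1.0):
        # shared weights: the contributions of x1 and x2 are summed
        for i, row in enumerate(self.grad_weight):
            for j in range(len(row)):
                row[j] += scale * (grad_output.x1[i] * inp.x1[j]
                                   + grad_output.x2[i] * inp.x2[j])


# Usage, mirroring the Lua example: N = 10, all-ones weight, x1 = 0, x2 = 1.
N = 10
module = CustomLinear(N, 1)
module.weight = [[1.0] * N]
x = CustomDataStructure([0.0] * N, [1.0] * N)
y = module.update_output(x)
grad_y = CustomDataStructure([1.0], [1.0])
grad_x = module.update_grad_input(x, grad_y)
module.acc_grad_parameters(x, grad_y)
print(y.x1, y.x2)
```

The forward values match what the Lua example computes (0 for x1 and 10 for x2); the backward methods follow the same per-member pattern as the forward pass.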

Conclusion

While the libraries show many differences, they are all tensor-oriented, i.e. they use tensors (dense or sparse) as the main data structure and define operations on these tensors. On more complex data structures, however, the discussed libraries (Tensorflow, PyTorch, Theano and Torch) do not work as easily. Overall, I can only recommend Theano and Torch for deep learning on non-tensor data structures. While Theano might require some work to get it working, the example shown for Torch can easily be extended to other data structures.

References

  • [1] Gernot Riegler, Ali Osman Ulusoy, Andreas Geiger. OctNet: Learning Deep 3D Representations at High Resolutions. CoRR, 2016.
  • [2] Ben Graham. Sparse 3D convolutional neural networks. BMVC, 2015.
  • [3] Martin Engelcke, Dushyant Rao, Dominic Zeng Wang, Chi Hay Tong, Ingmar Posner. Vote3Deep: Fast Object Detection in 3D Point Clouds Using Efficient Convolutional Neural Networks. CoRR, 2016.
What is your opinion on this article? Let me know your thoughts on Twitter @davidstutz92 or LinkedIn in/davidstutz92.