
16th February 2017

SNIPPET

Slightly adapted example for adding new operations in TensorFlow, taken from the official documentation. The files should be copied to tensorflow/core/user_ops. The new operation is then compiled from the TensorFlow root using bazel build -c opt //tensorflow/core/user_ops:zero_out.so. The generated .so file can usually be found under bazel-bin; for this target it is bazel-bin/tensorflow/core/user_ops/zero_out.so. The code does not yet include a gradient function for the new operation.
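Because the adapted kernel writes 0 to every output element regardless of the input, the missing gradient is easy to state. For an input x, output z = ZeroOut(x) and some scalar loss L, the chain rule gives:

$$
z_j = 0 \;\; \forall j
\quad\Rightarrow\quad
\frac{\partial z_j}{\partial x_i} = 0
\quad\Rightarrow\quad
\frac{\partial L}{\partial x_i} = \sum_j \frac{\partial L}{\partial z_j}\,\frac{\partial z_j}{\partial x_i} = 0 .
$$

A gradient function would therefore only need to return a zero tensor shaped like the input. Since the op is declared on int32 tensors, an alternative (an assumption, not something the original snippet does) is to mark it as non-differentiable, e.g. with tf.NotDifferentiable("ZeroOut") in TensorFlow 1.x.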

BUILD
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    name = "zero_out.so",
    srcs = ["zero_out.cc"],
)
zero_out.cc
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
        c->set_output(0, c->input(0));
        return Status::OK();
    });

class ZeroOutOp : public OpKernel {
public:
    explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

    void Compute(OpKernelContext* context) override {
        // Grab the input tensor and view it as a flat 1-D array.
        const Tensor& input_tensor = context->input(0);
        auto input = input_tensor.flat<int32>();

        // Allocate an output tensor with the same shape as the input.
        Tensor* output_tensor = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
        auto output = output_tensor->flat<int32>();

        // Set every output element to zero.
        const int N = input.size();
        for (int i = 0; i < N; i++) {
            output(i) = 0;
        }
    }
};

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
zero_out_test.py
import os
import tensorflow as tf

# Adapt this path by looking for the generated zero_out.so!
# load_op_library does not expand "~" itself, so expand it explicitly.
zero_out_module = tf.load_op_library(os.path.expanduser('~/tensorflow-master/bazel-bin/tensorflow/core/user_ops/zero_out.so'))
with tf.Session(''):
    print(zero_out_module.zero_out([[1, 2], [3, 4]]).eval())
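To check what the test script should print without reading through the C++ code, here is a minimal pure-Python reference of what the adapted kernel computes. zero_out_reference is a hypothetical helper, not part of TensorFlow or of the original snippet; it treats nested lists as tensors and returns zeros in the same nested shape.

```python
from typing import Any


def zero_out_reference(x: Any) -> Any:
    """Hypothetical pure-Python reference for the adapted ZeroOut kernel.

    Nested lists/tuples play the role of a tensor: the result keeps the
    input's nesting (i.e. its shape) and replaces every element with 0,
    like the C++ loop that writes output(i) = 0 for all i.
    Tuples are returned as lists.
    """
    if isinstance(x, (list, tuple)):
        # Recurse into each dimension so the output shape matches the input.
        return [zero_out_reference(item) for item in x]
    # Leaf element: the kernel ignores its value and writes 0.
    return 0


if __name__ == "__main__":
    print(zero_out_reference([[1, 2], [3, 4]]))  # [[0, 0], [0, 0]]
```

With the kernel above, the test script should therefore print a 2x2 array of zeros, matching zero_out_reference([[1, 2], [3, 4]]).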
