# DAVID STUTZ

24th February 2017

## Batch Normalization in Tensorflow

A simple implementation of batch normalization in Tensorflow. Note that the implementation is only meant for training: the mean and variance are always computed from the current batch, so for testing (e.g. on individual samples) they should instead be fixed in advance. Also note that `tf.nn.moments` is replaced by a custom method for computing mean and variance to avoid the problem described here.

`layers.py`
```python
import tensorflow as tf
import utils

def batch_normalization(name, x, **kwargs):
    """
    Batch normalization layer.

    Normalizes ``x`` with the mean and variance of the current batch and then
    applies a learned per-channel ``scale`` and ``offset``. Statistics are
    computed per channel, i.e. over every axis except the last one, so the
    last axis of ``x`` is treated as the channel/feature axis.

    :param name: scope name
    :type name: str
    :param x: input tensor of rank >= 2 whose last axis holds the channels or
        features (e.g. ``[batch, height, width, channels]`` for convolutional
        layers, ``[batch, features]`` for fully connected layers)
    :type x: tensorflow.Tensor

    :Keyword Arguments:
        * *variance_epsilon* (float) -- epsilon to add to variance before dividing (0.0);
          with the default of 0.0, a channel with zero batch variance divides by zero,
          so a small positive value is usually advisable

    :return: batch normalization tensor
    :rtype: tensorflow.Tensor
    :raises ValueError: if ``x`` has unknown rank or fewer than two dimensions
    """

    variance_epsilon = utils.get_kwargs(kwargs, 'variance_epsilon', 0.0)
    assert variance_epsilon >= 0

    shape = x.get_shape()
    if shape.ndims is None or shape.ndims < 2:
        raise ValueError('batch_normalization expects an input of known rank >= 2, got shape %s' % shape)

    channels = shape.as_list()[-1]

    # Reduce over all axes except the channel axis: [0, 1, 2] for convolutional
    # inputs and [0] for fully connected inputs. Reducing a 2D input over
    # [0, 1] would collapse the feature axis into a single scalar mean/variance,
    # which is no longer per-feature batch normalization.
    axes = list(range(shape.ndims - 1))

    with tf.name_scope(name):
        offset = tf.Variable(tf.constant(0.0, shape=[channels]), name='offset', trainable=True)
        scale = tf.Variable(tf.constant(1.0, shape=[channels]), name='scale', trainable=True)

        mean, variance = utils.moments('moments', x, axes)

        return tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon)
```
`utils.py`
```python
import tensorflow as tf

def count_elements(name, x):
    """
    Count the number of elements in the given tensor.

    The count is returned with the same dtype as ``x`` so that it can be used
    directly in arithmetic with tensors derived from ``x`` (e.g. averaging).

    :param name: scope name
    :type name: str
    :param x: input tensor
    :type x: tensorflow.Tensor
    :return: number of elements in ``x`` as a scalar tensor of dtype ``x.dtype``
    :rtype: tensorflow.Tensor
    """

    with tf.name_scope(name):
        # tf.size reads the (possibly dynamic) shape instead of materializing a
        # tensor of ones as large as x only to sum it; the integer count is
        # exact and the cast keeps the previous contract of returning x.dtype.
        return tf.cast(tf.size(x), x.dtype)

def moments(name, x, dimensions):
    """
    Compute mean and variance for the given tensor along the given dimensions.

    The moments are computed from running sums as ``mean = E[x]`` and
    ``variance = E[x^2] - E[x]^2``, where the expectations are taken over
    ``dimensions``.

    :param name: scope name
    :type name: str
    :param x: input tensor
    :type x: tensorflow.Tensor
    :param dimensions: list of dimensions to compute moments over
    :type dimensions: [int]
    :return: mean and variance tensors; the reduced dimensions are removed
    :rtype: (tensorflow.Tensor, tensorflow.Tensor)
    """

    with tf.name_scope(name):
        # Tensor operators are used instead of tf.mul / tf.sub, which were
        # removed in TensorFlow 1.0.
        total = tf.reduce_sum(x, dimensions)
        squared_total = tf.reduce_sum(x * x, dimensions)
        # Number of input elements summed into each output entry:
        # size of x divided by size of the reduced tensor.
        elements = count_elements('elements', x) / count_elements('sum_elements', total)

        mean = total / elements
        # E[x^2] - E[x]^2 is non-negative mathematically, but floating-point
        # cancellation can make it slightly negative when the variance is small
        # relative to the mean; clamp so a later sqrt does not produce NaN.
        variance = tf.maximum(squared_total / elements - mean * mean, 0.0)

        return mean, variance
```

