More Examples for Working with Torch

This article is a short follow-up to my initial collection of examples for getting started with Torch. In the meantime, through a series of additional articles, the corresponding GitHub repository has grown to include not only basic examples but also more advanced ones, such as variational auto-encoders, generative adversarial networks and adversarial auto-encoders. This article aims to provide a short overview of the added examples.

This article is meant as a short follow-up to my 2017 article on getting started with Torch. Since then, I have added several examples to the GitHub repository:

Torch Examples on GitHub

These examples partly correspond to several articles, as outlined below, and include more advanced concepts such as variational auto-encoders, residual convolutional modules, or generative adversarial networks.

Overview of Articles

Following my first article on working with Torch, Examples for Getting Started with Torch, I published several more advanced articles: on variational auto-encoders and their variants, on implementing Torch modules in C and CUDA, and on very specific topics such as PointNets [][][] or OctNets []. Altogether, these articles form a series for diving deep into Torch and Lua with only minimal requirements: a basic understanding of programming (for example in Python or MATLAB, which are both similar to Lua) and a beginner's background in deep learning.

Articles:

Many more custom modules and utilities for Torch can also be found in the following repository, which is part of my work on weakly-supervised 3D shape completion: davidstutz/aml-improved-shape-completion.

Additionally, I added several more examples to the davidstutz/torch-examples repository, some of which are described in detail below.

Custom Data Structures

As outlined briefly in my article on custom data structures in deep learning frameworks, the following example illustrates how to implement and subsequently use custom data structures for deep learning in Torch. A great example of a custom data structure implemented and used in Torch is OctNets [].

-- Small example to test forward and backward passes of custom data structures.

require('math')
require('torch')
require('nn')
require('lfs')

-- Make modules in the current directory (e.g. init.lua) loadable via require.
package.path = package.path .. ";" .. lfs.currentdir() .. '/?.lua'
init = require('init')

--- @class CustomDataStructure
-- This class will be our simple test data structure.
CustomDataStructure = {}
CustomDataStructure.__index = CustomDataStructure

--- Creates a new CustomDataStructure with 0-vectors of the given size
-- @param n vector size
function CustomDataStructure.create(n)
  local cds = {}
  setmetatable(cds, CustomDataStructure)
  cds.x1 = torch.Tensor(n):fill(0)
  cds.x2 = torch.Tensor(n):fill(0)
  return cds
end

--- @class CustomLinear
CustomLinear, CustomLinearParent = torch.class('nn.CustomLinear', 'nn.Module')

--- Initialize the layer specifying the number of input and output units.
-- @param nInputUnits number of input units
-- @param nOutputUnits number of output units
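function CustomLinear:__init(nInputUnits, nOutputUnits)
  -- Torch classes use __init as constructor; initialize the nn.Module parent first.
  CustomLinearParent.__init(self)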
  self.nInputUnits = nInputUnits
  self.nOutputUnits = nOutputUnits
  self.weight = torch.Tensor(nOutputUnits, nInputUnits):fill(0)
end

--- Compute output.
-- @param input input of type CustomDataStructure
-- @return output of type CustomDataStructure
function CustomLinear:updateOutput(input)
  self.output = CustomDataStructure.create(self.nOutputUnits)
  self.output.x1 = torch.mv(self.weight, input.x1)
  self.output.x2 = torch.mv(self.weight, input.x2)
  return self.output
end

--- Backward pass is not implemented in this example.
function CustomLinear:updateGradInput(input, gradOutput)
  assert(false, 'backward pass not implemented')
end

N = 10
x1 = torch.Tensor(N):fill(0)
x2 = torch.Tensor(N):fill(1)
x = CustomDataStructure.create(N)
x.x1 = x1
x.x2 = x2

model = nn.Sequential()
module = nn.CustomLinear(10, 1)
module.weight = torch.Tensor(1, 10):fill(1)
model:add(module)

y = model:forward(x)
print(y.x1)
print(y.x2)
Listing : Implementation of custom data structure, consisting of tensors (x1 and x2) and a linear module operating on this data structure.

Listing shows the implementation of CustomDataStructure and CustomLinear, a deep learning module operating on CustomDataStructure objects:

  1. First, CustomDataStructure defines a custom type, a class wrapping two Torch tensors that can be accessed via .x1 and .x2.
  2. CustomLinear directly extends nn.Module and operates on the defined data structure. The updateOutput method takes as input an instance of CustomDataStructure and applies a linear operation to both x1 and x2. Note that the backward pass is omitted for simplicity but works analogously.
  3. Finally, to test the setup, a simple CustomDataStructure is instantiated and passed through the custom linear module.
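The listing above requires Torch and skips the backward pass. To make the mechanics explicit, including the backward pass that "works analogously", the following is a minimal pure-Lua sketch in which plain Lua tables stand in for torch.Tensor. The helpers mv and mtv and the constructor CustomLinear.new are hypothetical stand-ins, not Torch API; only the method names updateOutput, updateGradInput and accGradParameters follow the nn.Module conventions.

```lua
-- Pure-Lua sketch (no Torch): plain tables stand in for torch.Tensor.
-- CustomDataStructure wraps two vectors, x1 and x2.
local CustomDataStructure = {}
CustomDataStructure.__index = CustomDataStructure

function CustomDataStructure.create(n)
  local cds = setmetatable({x1 = {}, x2 = {}}, CustomDataStructure)
  for i = 1, n do cds.x1[i], cds.x2[i] = 0, 0 end
  return cds
end

-- Hypothetical helper: y = W x, with W stored as a table of rows.
local function mv(W, x)
  local y = {}
  for i = 1, #W do
    local s = 0
    for j = 1, #x do s = s + W[i][j] * x[j] end
    y[i] = s
  end
  return y
end

-- Hypothetical helper: g = W^T d, where W has nIn columns.
local function mtv(W, d, nIn)
  local g = {}
  for j = 1, nIn do
    local s = 0
    for i = 1, #W do s = s + W[i][j] * d[i] end
    g[j] = s
  end
  return g
end

local CustomLinear = {}
CustomLinear.__index = CustomLinear

function CustomLinear.new(nIn, nOut)
  local m = setmetatable({nIn = nIn, nOut = nOut, weight = {}, gradWeight = {}}, CustomLinear)
  for i = 1, nOut do
    m.weight[i], m.gradWeight[i] = {}, {}
    for j = 1, nIn do m.weight[i][j], m.gradWeight[i][j] = 0, 0 end
  end
  return m
end

-- Forward: the same linear map is applied to both fields.
function CustomLinear:updateOutput(input)
  self.output = CustomDataStructure.create(self.nOut)
  self.output.x1 = mv(self.weight, input.x1)
  self.output.x2 = mv(self.weight, input.x2)
  return self.output
end

-- Gradient w.r.t. the input: W^T applied to each gradient field.
function CustomLinear:updateGradInput(input, gradOutput)
  self.gradInput = CustomDataStructure.create(self.nIn)
  self.gradInput.x1 = mtv(self.weight, gradOutput.x1, self.nIn)
  self.gradInput.x2 = mtv(self.weight, gradOutput.x2, self.nIn)
  return self.gradInput
end

-- Gradient w.r.t. the shared weight: contributions of both fields are summed.
function CustomLinear:accGradParameters(input, gradOutput)
  for i = 1, self.nOut do
    for j = 1, self.nIn do
      self.gradWeight[i][j] = self.gradWeight[i][j]
        + gradOutput.x1[i] * input.x1[j] + gradOutput.x2[i] * input.x2[j]
    end
  end
end

-- Usage mirroring the listing: N = 10, x1 = 0, x2 = 1, weight filled with 1.
local N = 10
local x = CustomDataStructure.create(N)
for i = 1, N do x.x2[i] = 1 end
local layer = CustomLinear.new(N, 1)
for j = 1, N do layer.weight[1][j] = 1 end

local y = layer:updateOutput(x)
print(y.x1[1], y.x2[1]) -- prints 0 and 10

local gradOutput = CustomDataStructure.create(1)
gradOutput.x1[1], gradOutput.x2[1] = 1, 1
local g = layer:updateGradInput(x, gradOutput)
layer:accGradParameters(x, gradOutput)
```

Because the weight is shared between x1 and x2, its gradient accumulates contributions from both fields, which is the main point to keep in mind when adding a backward pass to a module operating on a custom data structure.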

Auto-Encoder with Residual Units

The following example extends the basic auto-encoder example from this article and introduces residual blocks (for example, see []):

-- Auto-encoder example using residual units - residual units do not really make
-- sense in this context, but serve as an example.
-- Some good examples for residual units:
-- - https://github.com/arunpatala/residual.mnist
-- - https://github.com/gcr/torch-residual-networks

require('math')
require('torch')
require('nn')
require('optim')
require('lfs')

-- Make modules in the current directory (e.g. init.lua) loadable via require.
package.path = package.path .. ";" .. lfs.currentdir() .. '/?.lua'
init = require('init')

N = 1000
D = 100

inputs = torch.Tensor(N, D)
outputs = torch.Tensor(N, D)

for i = 1, N do 
  outputs[i] = torch.ones(D)
  inputs[i] = torch.cmul(outputs[i], torch.randn(D)*0.05 + 1)

  if i%100 == 0 then
    print('[Data] '..i)
  end
end

--- Simple residual layer.
-- @param input the input model
-- @param nInput input units of linear layer
-- @param nOutput output units of linear layer
-- @return model with residual layer
function addResidualLayer(input, nInput, nOutput)
  local cat = nn.ConcatTable()
  local unit = nn.Linear(nInput, nOutput)
  cat:add(unit)
  cat:add(nn.Identity())

  input:add(cat)
  input:add(nn.CAddTable())
  input:add(nn.Tanh(true))
  return input
end

model = nn.Sequential()
model = addResidualLayer(model, D, D)
model:add(nn.Linear(D, D))
model = init(model, 'xavier')

batchSize = 10
learningRate = 0.01
momentum = 0.9
weightDecay = 0.05
criterion = nn.AbsCriterion()  

parameters, gradParameters = model:getParameters()

for t = 1, 2500 do

  -- Sample a random batch from the dataset.
  local shuffle = torch.randperm(N)
  shuffle = shuffle:narrow(1, 1, batchSize)
  shuffle = shuffle:long()

  local input = inputs:index(1, shuffle)
  local output = outputs:index(1, shuffle)

  --- Definition of the objective on the current mini-batch.
  -- This will be the objective fed to the optimization algorithm.
  -- @param x input parameters
  -- @return object value, gradients
  local feval = function(x)

     -- Get new parameters.
     if x ~= parameters then
        parameters:copy(x)
     end

     -- Reset gradients
     gradParameters:zero()

     -- Evaluate function on mini-batch.
     local pred = model:forward(input)
     local f = criterion:forward(pred, output)

     -- Estimate df/dW.
     local df_do = criterion:backward(pred, output)
     model:backward(input, df_do)

     -- weight decay
     if weightDecay > 0 then
        f = f + weightDecay * torch.norm(parameters,2)^2/2
        gradParameters:add(parameters:clone():mul(weightDecay))
     end

     -- return f and df/dX
     return f, gradParameters
  end

  sgd_state = sgd_state or {
      learningRate = learningRate,
      momentum = momentum,
      learningRateDecay = 5e-7
  }

  -- Returns the new parameters and the objective evaluated
  -- before the update.
  p, f = optim.sgd(feval, parameters, sgd_state)

  print('[Training] '..t..': '..f[1])
end 
Listing : Example of a residual layer.

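Listing shows an example of a residual layer, implemented as the function addResidualLayer: the input is passed both through a linear layer and through an identity branch, the two results are summed using nn.CAddTable, and a tanh non-linearity is applied. Note that the auto-encoder in this example does not contain a bottleneck (meaning there is no explicit encoder and decoder), as the main goal is to demonstrate the residual block.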
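To see what addResidualLayer computes, here is a minimal pure-Lua sketch of the same forward pass, y = tanh(Wx + b + x), using plain tables instead of Torch tensors. residualForward and the hand-written tanh are illustrative helpers, not part of Torch; W must be square so that the identity branch matches the output size.

```lua
-- Numerically stable tanh; math.tanh is not available in Lua 5.3+.
local function tanh(v)
  local e = math.exp(-2 * math.abs(v))
  local t = (1 - e) / (1 + e)
  if v < 0 then return -t end
  return t
end

-- Hypothetical sketch of the residual unit: nn.Linear branch (W, b) plus
-- nn.Identity branch, summed as by nn.CAddTable, followed by tanh.
local function residualForward(W, b, x)
  local y = {}
  for i = 1, #W do
    local s = b[i]
    for j = 1, #x do s = s + W[i][j] * x[j] end
    y[i] = tanh(s + x[i])
  end
  return y
end

-- With W = 0 and b = 0 the linear branch vanishes and the unit reduces to tanh(x).
local W = {{0, 0}, {0, 0}}
local b = {0, 0}
local y = residualForward(W, b, {0.5, -1})
print(y[1], y[2])
```

The example illustrates the usual motivation for residual units: when the linear branch is small, the input still passes through the identity branch largely unchanged.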

Generative Adversarial Networks

Going beyond (variational) auto-encoders as generative models, the repository also contains examples of adversarial auto-encoders [] and generative adversarial networks [].

Generative adversarial networks simultaneously train a generator, which transforms samples from a simple (for example, uniform) distribution into generated samples, and a discriminator, which distinguishes the generated samples from samples of a reference distribution (for example, real images). By trying to fool the discriminator, the generator learns to produce more and more realistic samples, thereby matching the reference distribution. More details can be found in the original paper [].
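For reference, the training objective from the original paper can be written as the following minimax problem, where D(x) is the discriminator's estimated probability that x is real and z is drawn from the noise distribution p_z:

```latex
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

In the listing below, the discriminator maximizes this objective through nn.BCECriterion with targets 1 for real and 0 for generated samples. The generator update, in contrast, sets all targets to 1, i.e. it maximizes log D(G(z)) rather than minimizing log(1 - D(G(z))), a variant that the original paper also suggests because it provides stronger gradients early in training.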

The following listing shows a Torch implementation of a generative adversarial network. As such networks are usually difficult to train, the example shows how to learn a simple low-dimensional Gaussian distribution.
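The training data in the listing is drawn with torch.Tensor(N, D):normal(mean, std), and the generator is later judged by the mean and standard deviation of its samples. As a Torch-free illustration, the following sketch draws Gaussian samples using the Box-Muller transform and computes the same two statistics; normal and stats are hypothetical helpers, and the standard deviation is the unbiased estimate normalized by n - 1.

```lua
-- Hypothetical helper: one sample from N(mean, std^2) via the Box-Muller transform.
local function normal(mean, std)
  local u1 = 1 - math.random() -- in (0, 1], avoids log(0)
  local u2 = math.random()
  return mean + std * math.sqrt(-2 * math.log(u1)) * math.cos(2 * math.pi * u2)
end

-- Hypothetical helper: sample mean and unbiased standard deviation.
local function stats(xs)
  local n, sum = #xs, 0
  for i = 1, n do sum = sum + xs[i] end
  local m = sum / n
  local sq = 0
  for i = 1, n do sq = sq + (xs[i] - m) ^ 2 end
  return m, math.sqrt(sq / (n - 1))
end

-- Same setting as the listing: N = 10000 samples, mean 3, std 1.
math.randomseed(1)
local trainData = {}
for i = 1, 10000 do trainData[i] = normal(3, 1) end
local m, s = stats(trainData)
print(('[Data] mean %.3f std %.3f'):format(m, s))
```

Applied to the generator's outputs instead of the training data, the same statistics show whether the generator has matched the target Gaussian.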

-- Generative adversarial network.

require('nngraph')
require('torch')
require('nn')
require('cunn')
require('optim')
require('randomkit')

-- (1) Simple wrapper around optim.rmsprop allowing to turn
-- off optimization for either the discriminator or generator using
-- config.optimize.
--- Wrapper around RMSProp checking config.optimize.
-- @param opfunc objective to optimize
-- @param x parameters to optimize
-- @param config configuration
function rmsprop(opfunc, x, config)
    if config.optimize then
      return optim.rmsprop(opfunc, x, config)
    else
      local fx, dfdx = opfunc(x)
      return x, {fx}
    end
end

-- (2) Data generation; the example will try to learn a one-dimensional
-- Gaussian with the given mean and standard deviation
D = 1
nD = 2
N = 10000
mean = 3
std = 1

torch.setdefaulttensortype('torch.CudaTensor')
trainData = torch.Tensor(N, D):normal(mean, std)
testData = torch.Tensor(N, D):normal(mean, std)
print('[Data] mean ' .. torch.mean(trainData) .. ' std ' .. torch.std(trainData))

-- (3) The discriminator network is a simple classification network.
-- Depending on the task and the generator, the discriminator can be
-- made weaker by considering dropout, and made stronger by considering
-- batch normalization or a deeper/wider architecture. This might be necessary
-- if either the discriminator learns too fast, or the generator is too powerful.
-- Discriminator network.
model_D = nn.Sequential()
model_D:add(nn.Linear(D, 4*D))
model_D:add(nn.ReLU(true))
--model_D:add(nn.Dropout())
model_D:add(nn.Linear(4*D, 4*D))
model_D:add(nn.ReLU(true))
--model_D:add(nn.Dropout())
model_D:add(nn.Linear(4*D, 1))
model_D:add(nn.Sigmoid())

-- (4) The generator takes a nD-dimensional noise input and turns it into
-- a one-dimensional value which is supposed to be distributed like a Gaussian.
-- The generator can be made stronger with a deeper architecture or by
-- considering batch normalization, if necessary.
-- Generator network.
model_G = nn.Sequential()
model_G:add(nn.Linear(nD, 4*D))
--model_G:add(nn.BatchNormalization(4*D))
model_G:add(nn.ReLU(true))
--model_G:add(nn.Linear(4*D, 4*D))
--model_G:add(nn.BatchNormalization(4*D))
--model_G:add(nn.ReLU(true))
model_G:add(nn.Linear(4*D, D))

model_D = model_D:cuda()
model_G = model_G:cuda()

criterion = nn.BCECriterion()
criterion = criterion:cuda()

-- Retrieve parameters and gradients.
parameters_D, gradParameters_D = model_D:getParameters()
parameters_G, gradParameters_G = model_G:getParameters()

-- (5) Training will be governed by the regular learning parameters;
-- in addition, K denotes the number of times the discriminator is updated per step.
K = 3
batchSize = 8
momentum = 0
learningRate = 0.005
epochs = 5

sgdState_D = {
  learningRate = learningRate,
  momentum = momentum,
  learningRateDecay = 0.95,
  decayStep = math.floor(N/batchSize),
  optimize = true
}

sgdState_G = {
  learningRate = learningRate,
  momentum = momentum,
  learningRateDecay = 0.95,
  decayStep = math.floor(N/batchSize),
  optimize = true
}

-- For monitoring training.
smoothedLoss_G = 0
smoothedLoss_D = 0
smoothedCount = 0
optimizedCount_G = 0
optimizedCount_D = 0

-- (6) Training.
for t = 1, epochs do

  -- Do one epoch.
  local dataBatchSize = batchSize/2
  for t = 1, math.floor(N/dataBatchSize) do

    -- Random input for data generator.
    local inputs = torch.Tensor(batchSize, D)
    local targets = torch.Tensor(batchSize)
    local noise_inputs = torch.Tensor(batchSize, nD)

    --- Function to evaluate the discriminator.
    -- @param x parameters
    -- @return function value and derivative
    local fevalD = function(x)
      collectgarbage()

      if x ~= parameters_D then
        parameters_D:copy(x)
      end

      gradParameters_D:zero()

      local outputs = model_D:forward(inputs)

      -- Compute error on real and fake data.
      -- Remember that the first dataBatchSize elements are real.
      err_R = criterion:forward(outputs:narrow(1, 1, dataBatchSize), targets:narrow(1, 1, dataBatchSize))
      err_F = criterion:forward(outputs:narrow(1, dataBatchSize + 1,dataBatchSize), targets:narrow(1, dataBatchSize + 1, dataBatchSize))
      err = criterion:forward(outputs, targets)

      -- err_R, err_F will be roughly 0.7 if training gets stuck, so the margin should be
      -- a bit above 0.3!
      local margin = 0.31
      sgdState_D.optimize = true
      sgdState_G.optimize = true

      if err_F < margin or err_R < margin then
         sgdState_D.optimize = false
      end
      if err_F > (1.0 - margin) or err_R > (1.0 - margin) then
         sgdState_G.optimize = false
      end
      if math.abs(err - 0.5) < 0.01 then
        sgdState_G.optimize = false
      end

      -- Avoid a deadlock.
      -- Note that deadlock means both not optimizing and not challenging
      -- each other (i.e. if both are really bad).
      if sgdState_G.optimize == false and sgdState_D.optimize == false then
        local r = math.random()
        if r > 0.5 then
          sgdState_G.optimize = true
        else
          sgdState_D.optimize = true
        end
      end

      local f = criterion:forward(outputs, targets)

      local df_do = criterion:backward(outputs, targets)
      model_D:backward(inputs, df_do)

      return f, gradParameters_D
    end

    --- Function to evaluate the generator.
    -- @param x parameters
    -- @return function value and derivative
    local fevalG = function(x)
      collectgarbage()

      if x ~= parameters_G then
        parameters_G:copy(x)
      end

      gradParameters_G:zero()

      local samples = model_G:forward(noise_inputs)
      local outputs = model_D:forward(samples)
      local f = criterion:forward(outputs, targets)

      local df_samples = criterion:backward(outputs, targets)
      model_D:backward(samples, df_samples)

      local df_do = model_D.modules[1].gradInput
      model_G:backward(noise_inputs, df_do)

      return f, gradParameters_G
    end

    -- (6.1) The discriminator is updated K times; each time, a batch consisting of real
    -- and fake data is chosen and the discriminator is supposed to distinguish them.
    for k = 1, K do

      -- (6.1.1) Real data chosen randomly from the training set.
      local j = 1
      for i = 1, dataBatchSize do
        local sample = trainData[math.random(N)]
        inputs[j] = sample:clone()
        j = j + 1
      end

      targets[{{1,dataBatchSize}}]:fill(1)

      -- (6.1.2) Fake data is obtained by feeding uniform noise through the generator.
      -- Instead of uniform noise, other distributions different enough from the
      -- target distribution could be chosen.
      local samples = model_G:forward(torch.Tensor(dataBatchSize, nD):uniform(-1, 1))
      for i = 1, dataBatchSize do
        inputs[j] = samples[i]:clone()
        j = j + 1
      end

      targets[{{dataBatchSize + 1, batchSize}}]:fill(0)

      -- (6.1.3) fevalD additionally decides whether to update the discriminator or
      -- generator. Details can be found in fevalD; in short,
      -- the discriminator's error on real and fake data is considered.
      -- If the discriminator is too strong, it will not be updated further;
      -- if it is too weak, the generator is not updated further.
      p, f_D = rmsprop(fevalD, parameters_D, sgdState_D)

      if sgdState_D.optimize then
        optimizedCount_D = optimizedCount_D + 1
      end
    end

    -- (6.2) Update the generator, giving noise as input;
    -- fevalG encodes the objective, i.e. the generator is trained to fool
    -- the discriminator.
    noise_inputs:uniform(-1, 1)
    targets:fill(1)
    p, f_G = rmsprop(fevalG, parameters_G, sgdState_G)

    if sgdState_G.optimize then
      optimizedCount_G = optimizedCount_G + 1
    end

    smoothedLoss_G = smoothedLoss_G + f_G[1]
    smoothedLoss_D = smoothedLoss_D + f_D[1]
    smoothedCount = smoothedCount + 1

    if t%50 == 0 then
      print('[Training] G: ' .. smoothedLoss_G/smoothedCount .. ' (' .. optimizedCount_G/smoothedCount .. ') '
        .. 'D: ' .. smoothedLoss_D/smoothedCount .. ' (' .. optimizedCount_D/(K*smoothedCount) .. ')')

      optimizedCount_G = 0
      optimizedCount_D = 0
      smoothedLoss_G = 0
      smoothedLoss_D = 0
      smoothedCount = 0
    end
  end

  -- (7) Both the generator and the discriminator are tested.
  -- The discriminator is tested using real and fake data, for testing the generator,
  -- the mean and standard deviation of the predicted outputs are inspected.
  local noise_inputs = torch.Tensor(N, nD):uniform(-1, 1)
  local samples = model_G:forward(noise_inputs)

  local preds = model_D:forward(samples)
  preds[preds:lt(0.5)] = 0
  preds[preds:gt(0.5)] = 1
  local accuracy = 1 - torch.sum(preds)/N

  preds = model_D:forward(testData)
  preds[preds:lt(0.5)] = 0
  preds[preds:gt(0.5)] = 1
  accuracy = 0.5*accuracy + 0.5*torch.sum(preds)/N

  print('[Testing] accuracy (D): ' .. accuracy .. ', mean (G): ' .. torch.mean(samples) .. ', std (G): ' .. torch.std(samples))

  sgdState_D.momentum = math.min(sgdState_D.momentum + 0.0008, 0.7)
  sgdState_D.learningRate = math.max(learningRate*0.99^t, 0.000001)
  sgdState_G.momentum = math.min(sgdState_G.momentum + 0.0008, 0.7)
  sgdState_G.learningRate = math.max(learningRate*0.99^t, 0.000001)
end 
Listing : Implementation of a generative adversarial network to learn to generate low-dimensional Gaussian samples.

Listing [] can be broken down as follows:

  1. The function rmsprop is a simple wrapper around the RMSProp optimizer that allows turning off optimization for either the generator or the discriminator.
  2. As an example, samples from a low-dimensional (here, one-dimensional) Gaussian are generated as training data.
  3. The discriminator is a simple, binary classification network. In order to balance discriminator and generator, the discriminator can be made weaker using dropout, or made stronger using batch normalization and/or a wider/deeper architecture.
  4. The generator takes as input uniform noise and predicts low-dimensional samples, which are supposed to be distributed according to a Gaussian distribution. Similar to the discriminator, the generator gets stronger with a deeper/wider architecture or batch normalization.
  5. In addition to the usual hyper-parameters, K specifies the number of discriminator updates per generator update; this is done in order to balance training.
  6. Training:
    1. The discriminator is updated:
      1. Real data is sampled at random from the training set, that is, from the true distribution.
      2. Fake data is obtained by feeding samples from a uniform distribution through the current generator.
      3. fevalD also decides which network to update based on the discriminator's error: if the discriminator is too strong (resulting in low error), it is not updated and only the generator is updated; conversely, if it is too weak (resulting in high error), the generator is not updated.
    2. The generator is updated based on the output of the discriminator, with the goal to fool the discriminator into believing that the generated samples come from the true distribution.
  7. Both generator and discriminator are tested: the discriminator can easily be tested on distinguishing real and fake data; the generator can only be tested by inspecting the generated samples. In our case, this is done using the statistics of the generated samples; for images, qualitative inspection is usually required.
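The balancing rule in fevalD (step 6.1.3) is easy to lose in the full listing, so here it is isolated as a pure-Lua sketch. decideUpdates and bce are hypothetical helper names, and the random number used to break a deadlock is passed in explicitly so the function is deterministic. bce also shows where the value of roughly 0.7 mentioned in the code comments comes from: a discriminator that outputs 0.5 for every sample has a binary cross-entropy of ln 2, about 0.693.

```lua
-- Binary cross-entropy for a single prediction p in (0, 1) and target 0 or 1.
local function bce(p, target)
  return -(target * math.log(p) + (1 - target) * math.log(1 - p))
end

-- Hypothetical sketch of the balancing rule in fevalD.
-- errR, errF: discriminator loss on the real and fake halves of the batch;
-- err: loss on the full batch; r: number in [0, 1) used to break a deadlock.
local function decideUpdates(errR, errF, err, margin, r)
  local optD, optG = true, true
  if errF < margin or errR < margin then optD = false end         -- D too strong
  if errF > 1 - margin or errR > 1 - margin then optG = false end -- D too weak
  if math.abs(err - 0.5) < 0.01 then optG = false end
  if not optD and not optG then
    -- Deadlock: neither network would be updated, so pick one at random.
    if r > 0.5 then optG = true else optD = true end
  end
  return optD, optG
end

print(bce(0.5, 1))                              -- ln 2, roughly 0.693
print(decideUpdates(0.1, 0.2, 0.15, 0.31, 0.0)) -- strong D: only G is updated
```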

Adversarial Auto-Encoders

Adversarial auto-encoders enforce a simple prior, for example a Gaussian prior, on the latent code of regular auto-encoders. The distribution currently learned by the encoder is compared to the prior using an additional discriminator: the discriminator tries to distinguish between samples from the Gaussian prior and codes produced by the encoder, while the encoder tries to produce codes indistinguishable from samples of the prior. More details can be found in []. Listing [] shows a possible implementation; in contrast to the above example of a generative adversarial network, I will not go through the implementation in detail.
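To make the latent-space comparison described above concrete, the following hypothetical pure-Lua sketch assembles one discriminator batch: codes sampled from a standard Gaussian prior are labeled 1 ("real") and codes produced by the encoder are labeled 0 ("fake"). It follows the description in the paragraph above rather than the listing; latentBatch, gaussian and the encode argument are placeholders, not Torch API.

```lua
-- Standard normal sample via the Box-Muller transform.
local function gaussian()
  local u1, u2 = 1 - math.random(), math.random()
  return math.sqrt(-2 * math.log(u1)) * math.cos(2 * math.pi * u2)
end

-- Hypothetical sketch: build a discriminator batch on the latent code.
-- encode: placeholder for the encoder, mapping one input to a code table.
local function latentBatch(encode, inputs, codeDim)
  local n = #inputs
  local codes, targets = {}, {}
  for i = 1, n do
    local z = {}
    for d = 1, codeDim do z[d] = gaussian() end
    codes[i], targets[i] = z, 1                         -- sample from the prior
    codes[n + i], targets[n + i] = encode(inputs[i]), 0 -- encoder output
  end
  return codes, targets
end

-- Usage with a dummy encoder that maps every input to the zero code.
local codes, targets = latentBatch(function(x) return {0, 0} end, {1, 2, 3, 4}, 2)
print(#codes, targets[1], targets[#targets]) -- prints 8, 1 and 0
```

In the encoder update, the same encoder codes would be labeled 1 instead, so that the encoder is trained to fool the discriminator, analogously to the generator update of a generative adversarial network.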

-- Trying to implement an adversarial auto-encoder.

require('nngraph')
require('torch')
require('cutorch')
require('nn')
require('cunn')
require('optim')
require('image')
require('lfs')

-- Make modules in the current directory (e.g. init.lua) loadable via require.
package.path = package.path .. ";" .. lfs.currentdir() .. '/?.lua'
init = require('init')

-- Interferes with the image package.
--torch.setdefaulttensortype('torch.CudaTensor')

--- RMSProp on the given objective.
-- @param opfunc objective to optimize
-- @param x parameters to optimize
-- @param config configuration
function rmsprop(opfunc, x, config)
    -- (0) Check the configuration.
    assert(config)
    assert(config.learningRate)
    assert(config.momentum)
    assert(config.learningRateDecay)
    assert(config.numUpdates)
    assert(config.optimize ~= nil)

    -- (1) Update learning rate.
    if config.numUpdates%config.decayStep == 0 then
      config.learningRate = config.learningRate*config.learningRateDecay
    end

    local lr = config.learningRate
    local alpha = config.alpha or 0.9
    local epsilon = config.epsilon or 1e-8

    -- (2) Evaluate f(x) and df/dx.
    local fx, dfdx = opfunc(x)

    -- If the model is set to optimize, update parameters:
    if config.optimize == true then
        -- (3) Initialize mean square values and square gradient storage.
        if not config.m then
          config.m = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
          config.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx)
        end

        -- (4) Calculate new (leaky) mean squared values.
        config.m:mul(alpha)
        config.m:addcmul(1.0 - alpha, dfdx, dfdx)

        -- (5) Perform update.
        config.tmp:sqrt(config.m):add(epsilon)
        x:addcdiv(-lr, dfdx, config.tmp)
    end
    config.numUpdates = config.numUpdates + 1

    -- return x*, f(x) before optimization
    return x, {fx, dfdx}
end

--- Train the adversarial auto-encoder for one epoch, i.e. the discriminator, the decoder (acting as generator) and the auto-encoder.
-- @param data training data
function train(data)
  model_D:training()
  model_Dec:training()
  model_EncDec:training()

  epoch = epoch or 1
  local N = data:size(1)
  local dataBatchSize = batchSize/2

  -- Do one epoch.
  for t = 1, math.floor(N/dataBatchSize) do

    -- inputs holds the true inputs for reconstruction.
    -- targets will be adapted to the targets in each stage.
    -- codes will hold half true half random codes.
    local inputs = torch.Tensor(batchSize, 1, H, W)
    local outputs = torch.Tensor(dataBatchSize, 1, H, W)
    local targets = torch.Tensor(batchSize)
    local codes = torch.Tensor(batchSize, cD)

    inputs = inputs:cuda()
    outputs = outputs:cuda()
    targets = targets:cuda()
    codes = codes:cuda()

    --- Function to evaluate the discriminator.
    -- @param x parameters
    -- @return function value and derivative
    local fevalD = function(x)
      collectgarbage()

      if x ~= parameters_D then
        parameters_D:copy(x)
      end

      gradParameters_D:zero()

      -- Forward pass.
      local preds = model_D:forward(inputs)

      -- Compute error on real and fake data.
      -- Remember that the first dataBatchSize elements are real.
      err_R = criterion:forward(preds:narrow(1, 1, dataBatchSize), targets:narrow(1, 1, dataBatchSize))
      err_F = criterion:forward(preds:narrow(1, dataBatchSize + 1,dataBatchSize), targets:narrow(1, dataBatchSize + 1, dataBatchSize))
      err = criterion:forward(preds, targets)
      --print(err_R, err_F)

      -- err_R, err_F will be roughly 0.7 if training gets stuck, so the margin should be
      -- a bit above 0.3!
      local margin = 0.31
      sgdState_D.optimize = true
      sgdState_Dec.optimize = true

      if err_F < margin or err_R < margin then
         sgdState_D.optimize = false
      end
      if err_F > (1.0 - margin) or err_R > (1.0 - margin) then
         sgdState_Dec.optimize = false
      end
      if math.abs(err - 0.5) < 0.01 then
        sgdState_Dec.optimize = false
      end

      -- Avoid a deadlock.
      -- Note that deadlock means both not optimizing and not challenging
      -- each other (i.e. if both are really bad).
      if sgdState_Dec.optimize == false and sgdState_D.optimize == false then
        local r = math.random()
        if r > 0.5 then
          sgdState_Dec.optimize = true
        else
          sgdState_D.optimize = true
        end
      end

      local f = criterion:forward(preds, targets)

      local df_do = criterion:backward(preds, targets)
      model_D:backward(inputs, df_do)

      -- L_1 or L_2 penalties/regularizers.
      if coefL1 > 0 or coefL2 > 0 then
        -- Loss:
        f = f + coefL1 * torch.norm(parameters_D, 1)
        f = f + coefL2 * torch.norm(parameters_D, 2)^2/2

        -- Gradients:
        gradParameters_D:add(torch.sign(parameters_D):mul(coefL1) + parameters_D:clone():mul(coefL2))
      end

      --print('[Training][' .. t .. '] D ' .. f .. '(' .. gradParameters_D:norm() .. ')')
      return f, gradParameters_D
    end

    --- Function to evaluate the auto-encoder.
    -- @param x parameters
    -- @return function value and derivative
    local fevalEncDec = function(x)
      collectgarbage()

      if x ~= parameters_EncDec then
        parameters_EncDec:copy(x)
      end

      gradParameters_EncDec:zero()

      -- Forward pass.
      local preds = model_EncDec:forward(inputs)
      local f = criterion:forward(preds, outputs)

      --  Backward pass.
      local df = criterion:backward(preds, outputs)
      model_EncDec:backward(inputs, df)

      --print('[Training][' .. t .. '] EncDec ' .. f .. '(' .. gradParameters_EncDec:norm() .. ')')
      return f, gradParameters_EncDec
    end

    --- Function to evaluate the decoder/generator.
    -- @param x parameters
    -- @return function value and derivative
    local fevalDec = function(x)
      collectgarbage()

      if x ~= parameters_Dec then
        parameters_Dec:copy(x)
      end

      gradParameters_Dec:zero()

      -- Forward pass.
      local samples = model_Dec:forward(codes)
      local preds = model_D:forward(samples)
      local f = criterion:forward(preds, targets)

      --  Backward pass.
      local df_samples = criterion:backward(preds, targets)
      model_D:backward(samples, df_samples)

      local df = model_D.modules[1].gradInput
      model_Dec:backward(codes, df)

      --print('[Training][' .. t .. '] Dec ' .. f .. '(' .. gradParameters_Dec:norm() .. ')')
      return f, gradParameters_Dec
    end

    ----------------------------------------------------------------------
    -- (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))).
    -- Get half a minibatch of real, half fake.
    for k = 1, K do

      -- (1.1) Real data; first fetch random samples, then reconstruct them and
      -- take the reconstructions as inputs.
      for i = 1, dataBatchSize do
        local sample = data[math.random(N)]
        inputs[i] = sample:clone()
      end


      local samples = model_EncDec:forward(inputs:narrow(1, 1, dataBatchSize))
      local j = 1
      for i = 1, dataBatchSize do
        inputs[j] = samples[i]:clone()
        j = j + 1
      end

      -- (1.2) Real targets.
      targets[{{1,dataBatchSize}}]:fill(1)

      -- (1.3) Sampled data.
      -- Note not to sample from a normal distribution instead!
      codes:uniform(-1, 1)
      samples = model_Dec:forward(codes:narrow(1, 1, dataBatchSize))
      for i = 1, dataBatchSize do
        inputs[j] = samples[i]:clone()
        j = j + 1
      end

      -- (1.4) Sampled targets.
      targets[{{dataBatchSize + 1, batchSize}}]:fill(0)

      p, f_D = rmsprop(fevalD, parameters_D, sgdState_D)

      if sgdState_D.optimize then
        optimizedCount_D = optimizedCount_D + 1
      end
    end

    ----------------------------------------------------------------------
    -- (2) Update auto-encoder network: minimize reconstruction loss
    -- Get a full mini-batch of data samples.
    for l = 1, L do
      for i = 1, batchSize do
        local sample = data[math.random(N)]
        inputs[i] = sample:clone()
      end

      outputs = inputs:clone()

      p, f_EncDec = rmsprop(fevalEncDec, parameters_EncDec, sgdState_EncDec)
      optimizedCount_EncDec = optimizedCount_EncDec + 1
    end

    ----------------------------------------------------------------------
    -- (3) Update the decoder (generator) network: maximize log(D(G(z))).
    for m = 1, M do
      codes:uniform(-1, 1)
      targets:fill(1)
      p, f_Dec = rmsprop(fevalDec, parameters_Dec, sgdState_Dec)

      if sgdState_Dec.optimize then
        optimizedCount_Dec = optimizedCount_Dec + 1
      end
    end

    smoothedLoss_D = smoothedLoss_D + f_D[1]
    smoothedLoss_Dec = smoothedLoss_Dec + f_Dec[1]
    smoothedLoss_EncDec = smoothedLoss_EncDec + f_EncDec[1]
    smoothedGradParameters_D = smoothedGradParameters_D + torch.norm(f_D[2], 2)
    smoothedGradParameters_Dec = smoothedGradParameters_Dec + torch.norm(f_Dec[2], 2)
    smoothedGradParameters_EncDec = smoothedGradParameters_EncDec + torch.norm(f_EncDec[2], 2)
    smoothedCount = smoothedCount + 1

    if t%50 == 0 then
      -- Fractions of optimized rounds: EncDec runs L rounds, Dec M rounds and D K rounds per iteration.
      print('[Training] EncDec: ' .. smoothedLoss_EncDec/smoothedCount .. ' (' .. smoothedGradParameters_EncDec/smoothedCount .. ', ' .. optimizedCount_EncDec/(L*smoothedCount) .. ') '
        .. ' Dec: ' .. smoothedLoss_Dec/smoothedCount .. ' (' .. smoothedGradParameters_Dec/smoothedCount .. ', ' .. optimizedCount_Dec/(M*smoothedCount) .. ')'
        .. ' D: ' .. smoothedLoss_D/smoothedCount .. ' (' .. smoothedGradParameters_D/smoothedCount .. ', ' .. optimizedCount_D/(K*smoothedCount) .. ')')

      optimizedCount_D = 0
      optimizedCount_Dec = 0
      optimizedCount_EncDec = 0
      smoothedGradParameters_D = 0
      smoothedGradParameters_Dec = 0
      smoothedGradParameters_EncDec = 0
      smoothedLoss_D = 0
      smoothedLoss_Dec = 0
      smoothedLoss_EncDec = 0
      smoothedCount = 0
    end
  end

  -- TODO save snapshots!
  --torch.save(filename, {D = model_D, G = model_G, opt = opt})

  epoch = epoch + 1
end

--- Test the adversarial auto-encoder.
-- @param data real data to test on
function test(data)

  -- Test reconstruction.
  local preds = model_EncDec:forward(data)
  print('[Testing] Reconstruction (abs) ' .. torch.sum(torch.abs(preds - data))/(data:size(1)*data:size(3)*data:size(4)))

  -- Write first few data images and reconstructions to images.
  for i = 1, 100 do
    image.save('results/' .. i .. '.png', data[i][1])
    image.save('results/' .. i .. '_rec.png', preds[i][1])
  end

  -- Test discriminator.
  preds = model_D:forward(preds)
  preds[preds:gt(0.5)] = 1
  preds[preds:lt(0.5)] = 0
  print('[Testing] Discriminator (accuracy) ' .. torch.sum(preds)/data:size(1))
end

-- Training parameters.
K = 4 -- Number of D rounds per iteration ...
L = 2 -- Number of EncDec rounds per iteration ...
M = 1 -- Number of Dec rounds per iteration ...
coefL1 = 0
coefL2 = 0
batchSize = 8
momentum = 0
learningRate = 0.005

-- Data parameters.
H = 16
W = 16
rH = 8
rW = 8
cD = 2
N = 5000

-- Fix random seed.
--torch.manualSeed(1)

-- Generate rectangle data.
trainData = torch.Tensor(N, 1, H, W):fill(0)
for i = 1, N do
  local h = torch.random(2, rH)
  local w = torch.random(2, rW)
  local aH = torch.random(1, H - h)
  local aW = torch.random(1, W - w)
  trainData[i][1]:sub(aH, aH + h, aW, aW + w):fill(1)
end

testN = math.floor(0.1*N) -- Number of test samples.
testData = torch.Tensor(testN, 1, H, W):fill(0)
for i = 1, testN do
  local h = torch.random(2, rH)
  local w = torch.random(2, rW)
  local aH = torch.random(1, H - h)
  local aW = torch.random(1, W - w)
  testData[i][1]:sub(aH, aH + h, aW, aW + w):fill(1)
end

trainData = trainData:cuda()
testData = testData:cuda()

-- Discriminator network.
model_D = nn.Sequential()
model_D:add(nn.SpatialConvolution(1, 4, 3, 3, 1, 1, 1, 1))
model_D:add(nn.SpatialBatchNormalization(4))
model_D:add(nn.ReLU())
model_D:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
model_D:add(nn.SpatialConvolution(4, 8, 3, 3, 1, 1, 1, 1))
model_D:add(nn.SpatialBatchNormalization(8))
model_D:add(nn.ReLU())
model_D:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
model_D:add(nn.SpatialConvolution(8, 16, 3, 3, 1, 1, 1, 1))
model_D:add(nn.SpatialBatchNormalization(16))
model_D:add(nn.ReLU())
model_D:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
model_D:add(nn.View(4*16))
model_D:add(nn.Linear(4*16, 1))
model_D:add(nn.Sigmoid())
--model_D = init(model_D)

-- Encoder.
model_Enc = nn.Sequential()
model_Enc:add(nn.SpatialConvolution(1, 4, 3, 3, 1, 1, 1, 1))
model_Enc:add(nn.ReLU())
model_Enc:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
model_Enc:add(nn.SpatialConvolution(4, 8, 3, 3, 1, 1, 1, 1))
model_Enc:add(nn.ReLU())
model_Enc:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
model_Enc:add(nn.SpatialConvolution(8, 16, 3, 3, 1, 1, 1, 1))
model_Enc:add(nn.ReLU())
model_Enc:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
model_Enc:add(nn.View(4*16))
model_Enc:add(nn.Linear(4*16, cD))
--model_Enc = init(model_Enc)

-- Decoder.
model_Dec = nn.Sequential()
model_Dec:add(nn.Linear(cD, 4*16))
model_Dec:add(nn.View(16, 2, 2))
model_Dec:add(nn.SpatialUpSamplingNearest(2))
model_Dec:add(nn.SpatialConvolution(16, 8, 3, 3, 1, 1, 1, 1))
model_Dec:add(nn.ReLU())
model_Dec:add(nn.SpatialUpSamplingNearest(2))
model_Dec:add(nn.SpatialConvolution(8, 4, 3, 3, 1, 1, 1, 1))
model_Dec:add(nn.ReLU())
model_Dec:add(nn.SpatialUpSamplingNearest(2))
model_Dec:add(nn.SpatialConvolution(4, 1, 3, 3, 1, 1, 1, 1))
model_Dec:add(nn.Sigmoid())
--model_Dec = init(model_Dec)

-- Auto-encoder.
model_EncDec = nn.Sequential()
model_EncDec:add(model_Enc:clone('weight', 'bias', 'gradWeight', 'gradBias'))
model_EncDec:add(model_Dec:clone('weight', 'bias', 'gradWeight', 'gradBias'))

model_D = model_D:cuda()
model_Dec = model_Dec:cuda()
model_EncDec = model_EncDec:cuda()

-- Loss function: binary cross-entropy (negative log-likelihood of Bernoulli outputs).
criterion = nn.BCECriterion()
criterion = criterion:cuda()

-- Retrieve parameters and gradients.
parameters_D, gradParameters_D = model_D:getParameters()
parameters_Dec, gradParameters_Dec = model_Dec:getParameters()
parameters_EncDec, gradParameters_EncDec = model_EncDec:getParameters()

-- Optimizer settings and states.
sgdState_D = {
  learningRate = learningRate,
  momentum = momentum,
  learningRateDecay = 0.95,
  decayStep = math.floor(N/batchSize), -- Update learning rate each epoch.
  optimize = true,
  numUpdates = 0
}

sgdState_Dec = {
  learningRate = learningRate,
  momentum = momentum,
  learningRateDecay = 0.95,
  decayStep = math.floor(N/batchSize), -- Update learning rate each epoch.
  optimize = true,
  numUpdates = 0
}

sgdState_EncDec = {
  learningRate = learningRate,
  momentum = momentum,
  learningRateDecay = 0.95,
  decayStep = math.floor(N/batchSize), -- Update learning rate each epoch.
  optimize = true,
  numUpdates = 0
}

-- For printing smoothed values.
smoothedGradParameters_D = 0
smoothedGradParameters_Dec = 0
smoothedGradParameters_EncDec = 0
smoothedLoss_D = 0
smoothedLoss_Dec = 0
smoothedLoss_EncDec = 0
smoothedCount = 0

-- Count optimization steps to see how often the discriminator, the decoder
-- and the auto-encoder are updated.
optimizedCount_D = 0
optimizedCount_Dec = 0
optimizedCount_EncDec = 0

while true do
  train(trainData)
  test(testData)

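  -- After each epoch, increase the momentum (up to 0.7) and decay the
  -- learning rate (down to a minimum of 1e-6).
  sgdState_D.momentum = math.min(sgdState_D.momentum + 0.0008, 0.7)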
  sgdState_D.learningRate = math.max(learningRate*0.99^epoch, 0.000001)
  sgdState_Dec.momentum = math.min(sgdState_Dec.momentum + 0.0008, 0.7)
  sgdState_Dec.learningRate = math.max(learningRate*0.99^epoch, 0.000001)
  sgdState_EncDec.momentum = math.min(sgdState_EncDec.momentum + 0.0008, 0.7)
  sgdState_EncDec.learningRate = math.max(learningRate*0.99^epoch, 0.000001)
end 
Listing : An example of an adversarial auto-encoder implemented in Torch.
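The momentum and learning-rate updates at the end of the main loop can be checked in isolation. The following plain-Lua sketch (no Torch required) replays the same two update rules for a given number of epochs. `simulateSchedule` is a hypothetical helper, and it assumes that `epoch` counts completed epochs starting at 1, since the initial value of `epoch` is set outside this excerpt.

```lua
-- Plain-Lua sketch of the per-epoch schedule used in the main loop above.
-- simulateSchedule is a hypothetical helper (not part of the original code);
-- it assumes epoch counts completed epochs starting at 1.
function simulateSchedule(baseLearningRate, numEpochs)
  local momentum = 0 -- Matches the initial momentum of the script.
  local learningRate = baseLearningRate
  for epoch = 1, numEpochs do
    -- Momentum grows linearly and is capped at 0.7.
    momentum = math.min(momentum + 0.0008, 0.7)
    -- Learning rate decays exponentially and is bounded below by 1e-6.
    learningRate = math.max(baseLearningRate*0.99^epoch, 0.000001)
  end
  return learningRate, momentum
end

local learningRate, momentum = simulateSchedule(0.005, 1000)
print(learningRate, momentum)
```

With the settings of the listing (`learningRate = 0.005`, initial momentum 0), the learning rate reaches its floor of 1e-6 from epoch 848 on, and the momentum reaches its cap of 0.7 after about 875 epochs.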
  • [] Charles Ruizhongtai Qi, Hao Su, Kaichun Mo, Leonidas J. Guibas. PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation. CoRR abs/1612.00593 (2016).
  • [] Charles Ruizhongtai Qi, Li Yi, Hao Su, Leonidas J. Guibas. PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space. CoRR abs/1706.02413 (2017).
  • [] Haoqiang Fan, Hao Su, Leonidas J. Guibas. A Point Set Generation Network for 3D Object Reconstruction from a Single Image. CoRR abs/1612.00603 (2016).
  • [] Gernot Riegler, Ali Osman Ulusoy, Andreas Geiger. OctNet: Learning Deep 3D Representations at High Resolutions. CVPR 2017: 6620-6629.
  • [] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Deep Residual Learning for Image Recognition. CoRR abs/1512.03385 (2015).
  • [] Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron C. Courville, Yoshua Bengio. Generative Adversarial Networks. CoRR abs/1406.2661 (2014).
  • [] Alireza Makhzani, Jonathon Shlens, Navdeep Jaitly, Ian J. Goodfellow. Adversarial Autoencoders. CoRR abs/1511.05644 (2015).
