

Diederik P. Kingma, Shakir Mohamed, Danilo Jimenez Rezende, Max Welling. Semi-supervised Learning with Deep Generative Models. NIPS, 2014.

Kingma et al. propose a conditional variant of the variational auto-encoder for semi-supervised learning. A variational auto-encoder learns a latent representation $z$ corresponding to a data sample $x$. To this end, an inference network $q_\psi(z|x)$ is used to approximate the posterior distribution $p(z|x)$. In the discussed case, $q_\psi(z|x) = \mathcal{N}(z|\mu_\psi(x), \operatorname{diag}(\sigma^2_\psi(x)))$, where $\mu_\psi$ and $\sigma^2_\psi$ are modeled using a neural network. Following the variational principle, a lower bound on the log marginal likelihood of a data point $x$ is given by

$\log p_\theta(x) \geq E_{q_\psi(z|x)}[\log p_\theta(x|z)] - \mathrm{KL}(q_\psi(z|x) \,\|\, p_\theta(z))$ (1)

where $p_\theta(z)$ is usually a unit Gaussian, so that the Kullback-Leibler divergence can be computed analytically. $p_\theta(x|z)$ is usually a multivariate Gaussian or a Bernoulli distribution whose parameters are computed by another neural network, the generator (or decoder). This makes it easy to estimate $E_{q_\psi(z|x)}[\log p_\theta(x|z)]$ by sampling $z$ from $q_\psi(z|x)$ and to use it as a loss on top of the neural network.
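As a minimal sketch of Equation (1), assume a Bernoulli decoder, the diagonal Gaussian $q_\psi(z|x)$ from above and a unit Gaussian prior. The KL term then has the closed form $\frac{1}{2}\sum_j (\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2)$, and the expected log-likelihood can be estimated from a single reparameterized sample $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. The inputs `mu` and `log_var` (the inference network outputs) and the `decoder` function are hypothetical placeholders for the neural networks; all function names are my own.

```python
import math
import random


def kl_diag_gaussian_to_unit(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I))."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, log_var))


def bernoulli_log_likelihood(x, probs, eps=1e-7):
    """log p(x|z) for binary x with independent Bernoulli probabilities `probs`."""
    total = 0.0
    for xi, p in zip(x, probs):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += xi * math.log(p) + (1 - xi) * math.log(1.0 - p)
    return total


def elbo(x, mu, log_var, decoder, rng=random):
    """One-sample Monte Carlo estimate of the right-hand side of Equation (1).

    mu, log_var: outputs of a (hypothetical) inference network q(z|x) for x.
    decoder: hypothetical function mapping z to Bernoulli probabilities for x.
    """
    # Reparameterized sample: z = mu + sigma * eps with eps ~ N(0, I).
    z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]
    return bernoulli_log_likelihood(x, decoder(z)) - kl_diag_gaussian_to_unit(mu, log_var)


# Usage with a trivial decoder that ignores z and predicts 0.5 everywhere.
print(elbo([1, 0], [0.0, 0.0], [0.0, 0.0], lambda z: [0.5, 0.5], rng=random.Random(0)))
```

Since the KL term is exact, only the reconstruction term is noisy; averaging over several samples of $z$ reduces the variance of the estimate.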

In the semi-supervised case, i.e., if labels are available only for a small subset of the data, Kingma et al. propose two different approaches. First, a variational auto-encoder as described above is trained on all data (labeled and unlabeled), and the inference network $q_\psi(z|x)$ is used as a feature extractor for a classifier trained on the labeled data. Second, a conditional variational auto-encoder is trained on all data. Here, two cases arise: the label $y$ of a given sample $x$ is either available or not. In both cases, the model is extended as follows. The generator network $p_\theta(x|y,z)$ now depends on both the latent variable and the label. The inference model $q_\psi(z,y|x)$ decomposes as $q_\psi(z,y|x) = q_\psi(z|y,x) q_\psi(y|x)$, where $q_\psi(y|x)$ is specified as a categorical (multinomial) distribution:

$q_\psi(z|y,x) = \mathcal{N}(z|\mu_\psi(y,x), \operatorname{diag}(\sigma^2_\psi(x)))$,

$q_\psi(y|x) = Cat(y|\pi_\psi(x))$.
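To illustrate this factorization of the inference model (called M2 in the paper), the following sketch replaces the neural networks by random linear maps. This is a purely hypothetical stand-in: `ToyM2Inference` and its weights are not from the paper. It only shows the structural point of the two equations above: $\pi_\psi(x)$ is a softmax over class logits, the mean $\mu_\psi(y,x)$ receives both $x$ and a one-hot encoding of $y$, while the variance $\sigma^2_\psi(x)$ depends on $x$ only.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax, giving the class probabilities pi(x)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def one_hot(y, num_classes):
    return [1.0 if k == y else 0.0 for k in range(num_classes)]


def linear(weights, bias, inputs):
    """Hypothetical stand-in for a neural network: weights @ inputs + bias."""
    return [sum(w * v for w, v in zip(row, inputs)) + b for row, b in zip(weights, bias)]


def random_matrix(rng, rows, cols):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


class ToyM2Inference:
    """Hypothetical inference model q(z, y | x) = q(z | y, x) q(y | x)."""

    def __init__(self, x_dim, num_classes, z_dim, seed=0):
        rng = random.Random(seed)
        self.num_classes = num_classes
        self.w_pi, self.b_pi = random_matrix(rng, num_classes, x_dim), [0.0] * num_classes
        self.w_mu, self.b_mu = random_matrix(rng, z_dim, x_dim + num_classes), [0.0] * z_dim
        self.w_lv, self.b_lv = random_matrix(rng, z_dim, x_dim), [0.0] * z_dim

    def q_y(self, x):
        # q(y|x) = Cat(y | pi(x)).
        return softmax(linear(self.w_pi, self.b_pi, x))

    def q_z(self, x, y):
        # q(z|y,x) = N(z | mu(y, x), diag(sigma^2(x))):
        # the mean sees x and one_hot(y), the log-variance sees x only.
        mu = linear(self.w_mu, self.b_mu, list(x) + one_hot(y, self.num_classes))
        log_var = linear(self.w_lv, self.b_lv, x)
        return mu, log_var


model = ToyM2Inference(x_dim=3, num_classes=2, z_dim=4)
x = [0.5, -1.0, 2.0]
print(model.q_y(x))     # two class probabilities summing to one
print(model.q_z(x, 0))  # (mean, log-variance) of q(z | y=0, x)
```

Predicting the log-variance instead of $\sigma^2_\psi(x)$ directly is a common design choice that keeps the variance positive without constraints.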

Then, in the first case, the variational lower bound can be written as follows:

$\log p_\theta(x,y) \geq E_{q_\psi(z|x,y)}[\log p_\theta(x|y,z) + \log p_\theta(y) + \log p(z) - \log q_\psi(z|x,y)]$ (2)

which follows directly from Equation (1) by considering the new model. In the second case, the label is treated as a missing latent variable:

$\log p_\theta(x) \geq E_{q_\psi(y,z|x)}[\log p_\theta(x|y,z) + \log p_\theta(y) + \log p(z) - \log q_\psi(y,z|x)] = (\ast)$.

In this case, posterior inference is performed over the unknown label $y$. Let $\mathcal{L}(x,y)$ be the negative of the right-hand side of Equation (2), i.e., the negative lower bound for the case where the label $y$ is known. It follows that

$(\ast) = -\sum_y q_\psi(y|x) \mathcal{L}(x, y) + \mathcal{H}(q_\psi(y|x))$ (3)

where $\mathcal{H}$ denotes the entropy. This can be seen by noting that $q_\psi(y,z|x) = q_\psi(y|x)q_\psi(z|x,y)$ and substituting $\mathcal{L}(x,y)$. The bound for the whole dataset combines Equations (2) and (3): the negative bounds are summed over the labeled and the unlabeled data points, respectively. However, $q_\psi(y|x)$ only appears in Equation (3); thus, $q_\psi(y|x)$ is only learned from unlabeled data. This is undesirable, as the labeled data should also be used to train the classifier. Therefore, Kingma et al. add a classification loss $E_{\widetilde{p}(x,y)}[-\log q_\psi(y|x)]$, where $\widetilde{p}(x,y)$ denotes the distribution of the labeled data. This classification loss is additionally weighted; Kingma et al. use $0.1 \cdot N$ as weight, where $N$ is the total size of the dataset.
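The substitution behind Equation (3) can be spelled out step by step, using $\log q_\psi(y,z|x) = \log q_\psi(y|x) + \log q_\psi(z|x,y)$ and the fact that $\mathcal{L}(x,y)$ is the negative of the expectation on the right-hand side of Equation (2):

$$
\begin{aligned}
(\ast) &= \sum_y q_\psi(y|x)\, E_{q_\psi(z|x,y)}\big[\log p_\theta(x|y,z) + \log p_\theta(y) + \log p(z) - \log q_\psi(z|x,y) - \log q_\psi(y|x)\big] \\
&= \sum_y q_\psi(y|x)\big(-\mathcal{L}(x,y)\big) - \sum_y q_\psi(y|x) \log q_\psi(y|x) \\
&= -\sum_y q_\psi(y|x)\,\mathcal{L}(x,y) + \mathcal{H}(q_\psi(y|x)).
\end{aligned}
$$

The entropy term appears because $-\log q_\psi(y|x)$ does not depend on $z$, so its inner expectation is simply $-\log q_\psi(y|x)$.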

Training follows the discussion in [1], i.e., the standard training procedure for variational auto-encoders based on the reparameterization trick. Unfortunately, the gradients for the conditional variational auto-encoder are omitted. I would have expected a more detailed discussion of how the newly introduced terms in Equations (2) and (3) are handled during training.
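To make the newly introduced terms more concrete, the following sketch evaluates the overall objective (the negative bound to be minimized) from already computed log-densities: $\mathcal{L}(x,y)$ for labeled points, the negative of Equation (3), which I denote $\mathcal{U}(x) = \sum_y q_\psi(y|x)\mathcal{L}(x,y) - \mathcal{H}(q_\psi(y|x))$, for unlabeled points, and the classification loss weighted by $0.1 \cdot N$. It only computes values, not gradients; in practice, an automatic differentiation framework would provide those. The function names and input layout are hypothetical, and taking the classification loss as the average over labeled points is my reading of the expectation over $\widetilde{p}(x,y)$.

```python
import math


def neg_labeled_bound(log_px_given_yz, log_py, log_pz, log_qz_given_xy):
    """L(x, y): negative of the bracket in Equation (2), evaluated at one z ~ q(z|x,y).

    Using a single sample z makes this a one-sample Monte Carlo estimate of L(x, y).
    """
    return -(log_px_given_yz + log_py + log_pz - log_qz_given_xy)


def neg_unlabeled_bound(q_y, neg_bounds):
    """U(x) = sum_y q(y|x) L(x, y) - H(q(y|x)), the negative of Equation (3).

    q_y: class probabilities q(y|x); neg_bounds: L(x, y) for every class y.
    """
    entropy = -sum(p * math.log(p) for p in q_y if p > 0.0)
    return sum(p * bound for p, bound in zip(q_y, neg_bounds)) - entropy


def total_objective(labeled, unlabeled, num_data, alpha_scale=0.1):
    """Objective to minimize: sum of L over labeled points, sum of U over unlabeled
    points, plus alpha * (average classification loss) with alpha = alpha_scale * N.

    labeled: list of (L(x, y), q(y|x) of the true label y) pairs.
    unlabeled: list of (q(y|x) for all classes, [L(x, y) for all classes]) pairs.
    num_data: total dataset size N.
    """
    alpha = alpha_scale * num_data
    labeled_term = sum(bound for bound, _ in labeled)
    unlabeled_term = sum(neg_unlabeled_bound(q_y, bounds) for q_y, bounds in unlabeled)
    # Average of -log q(y|x) over the labeled data (the expectation over p~(x, y)).
    classification = (
        sum(-math.log(q_true) for _, q_true in labeled) / len(labeled) if labeled else 0.0
    )
    return labeled_term + unlabeled_term + alpha * classification


# Usage with made-up log-densities for one labeled and one unlabeled point.
labeled = [(neg_labeled_bound(-80.0, math.log(0.5), -3.0, -2.5), 0.7)]
unlabeled = [([0.6, 0.4], [85.0, 90.0])]
print(total_objective(labeled, unlabeled, num_data=2))
```

Note that evaluating $\mathcal{U}(x)$ requires $\mathcal{L}(x,y)$ for every class $y$, so the cost per unlabeled point grows linearly with the number of classes.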

What is your opinion on the summarized work? Or do you know related work that is of interest? Let me know your thoughts in the comments below or get in touch with me: