
Mixture Density Networks: Basics

Mixture Density Networks


I got interested in Mixture Density Networks while reading Bishop's book on machine learning. His original paper can be found here.

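Mixture density networks are useful in problems where a single input can map to multiple output values. This is where conventional neural networks trained with a sum-of-squares error do poorly: such a network learns the conditional average of the targets, which, for a multi-valued mapping, can fall between the branches and match none of them.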
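A minimal NumPy sketch of this failure mode, on a made-up two-branch dataset (an assumption for illustration, not data from the post): every x has one target at y = +1 and one at y = -1. A least-squares fit (here `np.polyfit`, a straight line) lands on their average, y = 0, which is not a valid output for any x.

```python
import numpy as np

# Hypothetical multi-valued mapping: each x appears twice, once with y = +1
# and once with y = -1.
x_demo = np.repeat(np.linspace(0.0, 1.0, 50), 2)
y_demo = np.tile([1.0, -1.0], 50)

# Least-squares straight-line fit, i.e. minimising the sum-of-squares error.
slope, intercept = np.polyfit(x_demo, y_demo, deg=1)
prediction = slope * x_demo + intercept

# The fit follows the average of the two branches (y = 0), matching neither.
print(np.abs(prediction).max())
```

A mixture density network avoids this by predicting a distribution with one mode per branch instead of a single value.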


The basic idea is that instead of using a neural network to learn a direct mapping from the input variables $\textbf{x}$ to the target variables $\textbf{t}$, one uses it to learn the parameters of a predefined distribution over $\textbf{t}$. In the scenario given in the book, this distribution is a mixture of Gaussians. Concretely, the conditional distribution of a target sample $\textbf{t}_i$ given its input $\textbf{x}_i$ is postulated to be

$$ p(\textbf{t}_i|\textbf{x}_i) = \sum^K_{k=1}\pi_{ik} \mathrm{N}(\textbf{t}_i|\mathbf{\mu}_{ik}(\textbf{x}),\sigma_{ik}^2(\textbf{x})),$$

where $\pi_{ik}$ are the mixing coefficients, and $\mathbf{\mu}_{ik}$ and $\sigma_{ik}^2$ are the means and variances of the Gaussian components, respectively. All three are functions of $\textbf{x}$.

Note that the problem setup implicitly assumes that the component Gaussians are spherical, i.e. each component has a single variance $\sigma_{ik}^2$ shared by all dimensions of $\textbf{t}_i$. This poses little problem if $\textbf{t}_i$ is a scalar (i.e. the Gaussians are univariate). However, one might want to relax this assumption if $\textbf{t}_i$ is a vector whose components are correlated.
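To make the spherical assumption concrete, here is a small SciPy check with hypothetical numbers: a Gaussian with covariance $\sigma_{ik}^2\mathbf{I}$ gives exactly the same density as a product of independent univariate Gaussians with a shared standard deviation. Within one component, the target dimensions are therefore treated as independent; one common way to relax this is to predict a full covariance (for example through a Cholesky factor), which is not shown here.

```python
import numpy as np
from scipy import stats

# Hypothetical 3-dimensional target and one spherical mixture component.
t_vec = np.array([0.5, -1.0, 2.0])
mu_k = np.array([0.0, -0.5, 1.0])
sigma_k = 0.8

# Spherical Gaussian: covariance sigma_k^2 * I (one variance for all dimensions).
p_spherical = stats.multivariate_normal.pdf(t_vec, mean=mu_k, cov=sigma_k**2 * np.eye(3))

# Product of independent univariate Gaussians with the same sigma_k.
p_product = np.prod(stats.norm.pdf(t_vec, loc=mu_k, scale=sigma_k))

print(p_spherical, p_product)
```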

Assuming the samples are independent, the likelihood of the dataset $\mathcal{D}$ is

$$p(\mathcal{D}) = \prod_i p(\textbf{t}_i|\textbf{x}_i).$$
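Besides turning the product into a sum, working with the log-likelihood matters numerically: a product of many per-sample densities quickly underflows in floating point. A tiny sketch with made-up per-sample values (1,000 samples, each with density 0.05):

```python
import numpy as np

# Hypothetical per-sample likelihoods p(t_i|x_i).
per_sample = np.full(1000, 0.05)

likelihood = np.prod(per_sample)              # 0.05**1000 is about 1e-1301: underflows to 0.0
log_likelihood = np.sum(np.log(per_sample))   # 1000 * log(0.05): a finite number

print(likelihood, log_likelihood)
```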

As usual, we then maximise the log-likelihood of the dataset. Explicitly, the error function that we minimise is the average negative log-likelihood:

$$\mathcal{E} = -\frac{1}{N}\sum_{i=1}^N\log\sum^K_{k=1}\pi_{ik} \mathrm{N}(\textbf{t}_i|\mathbf{\mu}_{ik}(\textbf{x}),\sigma_{ik}^2(\textbf{x})).$$
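The error function can be sketched in NumPy for scalar targets as below. The function name `mdn_nll` and the random parameters are illustrative assumptions; the inner sum is taken in log space with `scipy.special.logsumexp`, in the same spirit as the TensorFlow loss in the toy problem later on, and is checked against a direct evaluation of the formula.

```python
import numpy as np
from scipy import stats
from scipy.special import logsumexp


def mdn_nll(t, fracs, means, sigmas):
    """Average negative log-likelihood E for scalar targets (sketch).

    t: shape (N,) targets; fracs, means, sigmas: shape (N, K), the mixture
    parameters produced by the network for each input x_i.
    """
    t = np.asarray(t, dtype=float)[:, None]                     # (N, 1)
    # log[pi_ik * N(t_i | mu_ik, sigma_ik^2)] for every sample i and component k
    log_terms = (np.log(fracs)
                 - 0.5 * np.log(2.0 * np.pi * sigmas ** 2)
                 - 0.5 * ((t - means) / sigmas) ** 2)
    # log of the sum over k, computed stably, then averaged over the N samples
    return -np.mean(logsumexp(log_terms, axis=1))


# Usage with random, hypothetical mixture parameters (N = 5 samples, K = 3 components)
rng = np.random.default_rng(0)
t_demo = rng.normal(size=5)
fracs_demo = rng.dirichlet(np.ones(3), size=5)
means_demo = rng.normal(size=(5, 3))
sigmas_demo = rng.uniform(0.5, 2.0, size=(5, 3))

# Direct evaluation of the same formula, for comparison
direct = -np.mean(np.log(np.sum(
    fracs_demo * stats.norm.pdf(t_demo[:, None], means_demo, sigmas_demo), axis=1)))
print(mdn_nll(t_demo, fracs_demo, means_demo, sigmas_demo), direct)
```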

Key appeal factors

It is important to note here that the mixing coefficients, means and variances of the mixture components are functions of $\textbf{x}$, which means that a neural network can be used to compute them. Indeed, this is the use case Bishop intended. This is appealing to me for three reasons:

  1. as someone who has experience with neural networks, I can start applying it right away instead of fitting the mixture with the traditional expectation-maximisation (EM) algorithm;

  2. it opens up the opportunity to apply architectures such as convolutional neural networks to problems where translation invariance of the features is important;

  3. it learns the full conditional distribution of the targets, which means we can sample from it.


One would build a neural network that maps the inputs to the parameters of the Gaussian mixture. The network has $(L+2)K$ outputs, where $L$ is the number of target variables (the dimension of $\textbf{t}$), covering all the parameters required for the $K$ Gaussians: $L\times K$ outputs for the means, $K$ outputs for the mixing coefficients and $K$ outputs for the standard deviations.
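A sketch of how the $(L+2)K$ raw outputs might be laid out and split. Here $L$ is taken to be the dimension of the target $\textbf{t}$, since each component mean lives in target space. The function name `split_mdn_outputs` and the ordering [means | standard deviations | mixing coefficients] are assumptions of this sketch; any fixed ordering works if it is used consistently.

```python
import numpy as np


def split_mdn_outputs(a, L, K):
    """Split raw network outputs of shape (N, (L + 2) * K) into three groups (sketch)."""
    N = a.shape[0]
    if a.shape[1] != (L + 2) * K:
        raise ValueError("expected (L + 2) * K outputs per sample")
    a_mu = a[:, :L * K].reshape(N, K, L)       # L * K mean activations
    a_sigma = a[:, L * K:(L + 1) * K]          # K standard-deviation activations
    a_pi = a[:, (L + 1) * K:]                  # K mixing-coefficient activations
    return a_mu, a_sigma, a_pi


# Scalar target (L = 1) with K = 24 components, as in the toy problem below:
# the network needs (1 + 2) * 24 = 72 outputs.
a_demo = np.zeros((5, (1 + 2) * 24))
a_mu, a_sigma, a_pi = split_mdn_outputs(a_demo, L=1, K=24)
print(a_mu.shape, a_sigma.shape, a_pi.shape)  # (5, 24, 1) (5, 24) (5, 24)
```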

The values of $\mathbf{\mu}_{ik}$ are simply the corresponding activations of the output layer, with no nonlinearity applied. To be clear, the activations $\textbf{a}$ are the weighted sums of the previous layer's outputs, $\textbf{a} = \textbf{W}^T\textbf{z}+\textbf{b}$, where $\textbf{z}$ is the output of the previous layer and $\textbf{W}$ and $\textbf{b}$ are the weights and biases of the current layer, respectively.

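The mixing coefficients must satisfy $0 \le \pi_{ik} \le 1$ and $\sum_k\pi_{ik}=1$, so a softmax function is applied to the corresponding activations $a^{\pi}_{ik}$: $\pi_{ik} = \exp(a^{\pi}_{ik}) / \sum_{j=1}^K \exp(a^{\pi}_{ij})$.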

The standard deviations $\sigma_{ik}$ must be positive, so an exponential function is applied to the corresponding activations: $\sigma_{ik} = \exp(a^{\sigma}_{ik})$.
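The three output transformations described above can be sketched in NumPy as follows (the name `mdn_params` is hypothetical). The softmax subtracts the row maximum before exponentiating, which leaves the result unchanged but avoids overflow for large activations; this mirrors the `tf.exp` and `tf.nn.softmax` calls in the toy network below.

```python
import numpy as np


def mdn_params(a_mu, a_sigma, a_pi):
    """Map raw output-layer activations to valid mixture parameters (sketch)."""
    mu = a_mu                                             # means: used as-is
    sigma = np.exp(a_sigma)                               # exp keeps sigma > 0
    # softmax over the component axis; subtracting the row maximum first
    # avoids overflow in exp without changing the result
    shifted = a_pi - a_pi.max(axis=-1, keepdims=True)
    pi = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    return mu, sigma, pi


# Usage with random, hypothetical activations for 4 inputs and K = 3 components;
# the mixing activations are scaled up to show that large values are handled.
rng = np.random.default_rng(1)
mu_out, sigma_out, pi_out = mdn_params(rng.normal(size=(4, 3)),
                                       rng.normal(size=(4, 3)),
                                       50 * rng.normal(size=(4, 3)))
print(sigma_out.min() > 0, np.allclose(pi_out.sum(axis=-1), 1.0))  # True True
```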

For more details, I refer you to the original paper.

Beginning with a toy problem

Credit where credit's due

When I searched online for material on mixture density networks, I was surprised to find that a lot of tutorials on them already exist. This post is heavily inspired by, guided by and dependent on material published in:

My aim here is to learn about and demonstrate mixture density networks, and then move on to some ideas I have for applying such networks. Hence you will find quite a bit of material here that replicates the references above.

Problem description

I tried to replicate the results in these references. Basically, I tried to learn the underlying distribution that generates the data points shown below. As can be seen, many x values map to several y values.

Import Libraries

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf  # this notebook uses the TensorFlow 1.x graph API (placeholders, sessions, tf.layers)
from scipy import stats

%matplotlib inline
plt.rcParams["figure.figsize"] = (12,8)

sess = tf.InteractiveSession()

Generate Samples

The samples are generated according to the equation

$$ x = 7\sin(0.75y) + 0.5y + \epsilon,$$

where $\epsilon \sim \mathrm{N}(0,1)$.

In [2]:
N_samples = 1000

y = np.random.uniform(-10.5,10.5,N_samples)

x = (7*np.sin(0.75*y) + 0.5*y + np.random.normal(size=(N_samples)))

plt.plot(x, y, ".", alpha=0.5)
(Figure: scatter plot of the generated samples, y against x.)
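One way to confirm that the inverse mapping from x to y is multi-valued is to count how many y values in the sampled range give the same x under the noise-free part of the generating equation. The sketch below (the helper name `x_of_y` is hypothetical) does this for x = 0 by counting sign changes on a fine grid; the grid has an even number of points, so y = 0 itself is not a grid point.

```python
import numpy as np


def x_of_y(y):
    """Noise-free part of the generating equation: x = 7 sin(0.75 y) + 0.5 y."""
    return 7 * np.sin(0.75 * y) + 0.5 * y


# Count the y values in the sampled range [-10.5, 10.5] that give x = 0
# by counting sign changes of x_of_y on a fine grid.
y_grid = np.linspace(-10.5, 10.5, 10000)
n_solutions = np.count_nonzero(np.diff(np.sign(x_of_y(y_grid))) != 0)
print(n_solutions)  # 5
```

So at x = 0 the noise-free curve passes through five different y values, which is exactly the situation where a single-valued regression breaks down.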

Define network

One thing worth noting here is the use of tf.reduce_logsumexp, which implements the log-sum-exp trick: it subtracts the largest term before exponentiating, so computing the log of the sum over mixture components does not underflow or overflow.
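The trick itself can be sketched in a few lines of NumPy (the helper `log_sum_exp` is a hypothetical name; `scipy.special.logsumexp` is used only for comparison). The log terms below are made-up values for a point that lies far from every component, where the naive computation collapses to `-inf`.

```python
import numpy as np
from scipy.special import logsumexp


def log_sum_exp(a):
    """log(sum(exp(a))) via the log-sum-exp trick: shift by the maximum first."""
    a_max = np.max(a)
    return a_max + np.log(np.sum(np.exp(a - a_max)))


# Hypothetical per-component log terms; exp(-1000) underflows to 0 in float64.
log_terms = np.array([-1000.0, -1001.0, -1002.0])

naive = np.log(np.sum(np.exp(log_terms)))   # log(0) -> -inf (NumPy also emits a RuntimeWarning)
stable = log_sum_exp(log_terms)             # about -999.59
print(naive, stable, logsumexp(log_terms))
```

Because $\log\sum_k e^{a_k} = m + \log\sum_k e^{a_k - m}$ for any constant $m$, choosing $m = \max_k a_k$ keeps the largest exponent at $e^0 = 1$, so the sum can never be zero.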

In [3]:
x_ = tf.placeholder(tf.float32, [None, 1])
y_ = tf.placeholder(tf.float32, [None, 1])

hidden_size_1 = 24
n_components = 24
lam = 0.01

output_1 = tf.layers.dense(x_,
                           units=hidden_size_1,
                           activation=tf.sigmoid)

final_output = tf.layers.dense(output_1,
                               units=3*n_components)

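# Split the 3K outputs into K means, K log-standard-deviations and K mixture logits
means, sigma_act, fracs_act = tf.split(final_output, 3, axis=1)

# exp keeps the standard deviations positive
sigma = tf.exp(sigma_act)

# softmax makes the mixing coefficients non-negative and sum to one
fracs = tf.nn.softmax(fracs_act)

diff = y_-means

# -(t - mu)^2 / (2 sigma^2): the exponent of each Gaussian component
squared_dist = -tf.square(diff/sigma)/2

# log(pi_k) + log N(t | mu_k, sigma_k^2) for each k, combined over k with log-sum-exp;
# the loss is the average negative log-likelihood
loss = tf.reduce_mean(-tf.reduce_logsumexp(tf.log(fracs)+squared_dist-0.5*tf.log(np.pi*2*tf.square(sigma)), axis=1))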

train_op = tf.train.AdamOptimizer().minimize(loss)

Train network

The training loss, recorded every 500 steps, levels off (saturates) as training proceeds.

In [4]:
N_steps = 10000

train_loss = np.zeros((N_steps))

training_loss = []

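# Variables must be initialised before the first training step
sess.run(tf.global_variables_initializer())

feed = {x_: x.reshape((-1, 1)), y_: y.reshape((-1, 1))}

for i in range(N_steps):
    sess.run(train_op, feed_dict=feed)  # one full-batch Adam step
    if i % 500 == 0:
        print("Training batch: {0}".format(i))
        training_loss.append(sess.run(loss, feed_dict=feed))

plt.plot(training_loss)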
Training batch: 0
Training batch: 500
Training batch: 1000
Training batch: 1500
Training batch: 2000
Training batch: 2500
Training batch: 3000
Training batch: 3500
Training batch: 4000
Training batch: 4500
Training batch: 5000
Training batch: 5500
Training batch: 6000
Training batch: 6500
Training batch: 7000
Training batch: 7500
Training batch: 8000
Training batch: 8500
Training batch: 9000
Training batch: 9500
(Figure: training loss, recorded every 500 steps.)

Sampling for mixture density networks

Creating test inputs

In [5]:
x_test = np.linspace(-15,15,100)

test_mu, test_sigma, test_frac = sess.run([means, sigma, fracs], feed_dict={x_: x_test.reshape((-1, 1))})

Creating distribution heatmap

The heat map below shows that the learnt conditional density $p(y|x)$ is highest in the regions where the training data points lie.

In [7]:
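y_test = np.linspace(-10, 10, 100)   # y values at which p(y|x) is evaluated
dist_heatmap = np.zeros((100, 100))  # rows: y (largest at the top), columns: x

for i in range(100):
    current_mu = test_mu[i, :]
    current_sigma = test_sigma[i, :]
    current_frac = test_frac[i, :]
    current_dist = np.zeros(100)
    # p(y|x) = sum_k pi_k * N(y | mu_k, sigma_k^2)
    for mu_, sigma_, frac_ in zip(current_mu, current_sigma, current_frac):
        current_dist += frac_ * stats.norm.pdf(y_test, mu_, sigma_)
    # flip so that larger y values appear at the top of the image
    dist_heatmap[:, i] = current_dist[::-1]

f, ax = plt.subplots(figsize=(12, 12))
ax.imshow(dist_heatmap, extent=(-15, 15, -10, 10), aspect="auto")
ax.set_xlabel("x")
ax.set_ylabel("y")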
(Figure: heat map of the learnt density p(y|x) for x in [-15, 15] and y in [-10, 10].)
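The double loop above can also be written with NumPy broadcasting. The sketch below (the function name `mixture_density_grid` is hypothetical) evaluates the mixture density for every test input and every y value at once, and is checked here on small random parameters rather than the trained network's outputs.

```python
import numpy as np
from scipy import stats


def mixture_density_grid(y_grid, mu, sigma, frac):
    """Evaluate p(y|x) = sum_k frac_k N(y | mu_k, sigma_k^2) for many inputs at once.

    y_grid: shape (M,) values of y; mu, sigma, frac: shape (N, K), one row per input x.
    Returns an (M, N) array whose rows follow y_grid and whose columns follow the inputs.
    """
    y_col = np.asarray(y_grid, dtype=float)[:, None, None]                       # (M, 1, 1)
    comps = stats.norm.pdf(y_col, loc=mu[None, :, :], scale=sigma[None, :, :])  # (M, N, K)
    return np.sum(comps * frac[None, :, :], axis=-1)                            # (M, N)


# Usage with hypothetical parameters: N = 3 inputs, K = 4 components, M = 7 y values
rng = np.random.default_rng(2)
mu_demo = rng.normal(size=(3, 4))
sigma_demo = rng.uniform(0.5, 1.5, size=(3, 4))
frac_demo = rng.dirichlet(np.ones(4), size=3)
y_demo = np.linspace(-5.0, 5.0, 7)

grid = mixture_density_grid(y_demo, mu_demo, sigma_demo, frac_demo)
print(grid.shape)  # (7, 3)
```

With the notebook's arrays, `mixture_density_grid(y_test, test_mu, test_sigma, test_frac)[::-1]` should reproduce `dist_heatmap`; the `[::-1]` plays the same role as the flip in the loop.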

Sample data points

Sampling actual data points from the learnt distribution tells the same story: the sampled points fall in the same high-density regions as in the heat map.

In [10]:
samples_per_x = 20
samples = np.zeros((len(x_test)*samples_per_x, 2))  # columns: x, sampled y
for i, (mu_, sigma_, frac_) in enumerate(zip(test_mu, test_sigma, test_frac)):
    for j in range(samples_per_x):
        # pick a mixture component according to the mixing coefficients...
        mixture_idx = np.random.choice(n_components, p=frac_)
        samples[i*samples_per_x+j, 0] = x_test[i]
        # ...then draw y from that component's Gaussian
        samples[i*samples_per_x+j, 1] = np.random.normal(loc=mu_[mixture_idx], scale=sigma_[mixture_idx])
In [11]:
plt.plot(samples[:,0], samples[:,1], ".")
(Figure: 20 samples of y drawn from the learnt mixture at each test value of x.)
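The same two-step (ancestral) sampling, first choosing a component and then drawing from its Gaussian, can be vectorised. In this sketch the function name `sample_mixture` is hypothetical, and components are chosen by comparing one uniform number per draw against the cumulative mixing coefficients instead of calling `np.random.choice` in a loop.

```python
import numpy as np


def sample_mixture(mu, sigma, frac, n_per_input, rng):
    """Draw n_per_input samples of y per input from its Gaussian mixture (sketch).

    mu, sigma, frac: shape (N, K). Returns an array of shape (N, n_per_input).
    """
    N, K = frac.shape
    # Step 1: choose a component for every draw by inverting the cumulative
    # mixing coefficients with a uniform random number in [0, 1).
    u = rng.random((N, n_per_input, 1))                  # (N, S, 1)
    cdf = np.cumsum(frac, axis=1)[:, None, :]            # (N, 1, K)
    k = np.minimum((u >= cdf).sum(axis=-1), K - 1)       # (N, S); clip guards against rounding
    # Step 2: draw from the chosen component's Gaussian.
    rows = np.arange(N)[:, None]                         # (N, 1), broadcasts with k
    return rng.normal(loc=mu[rows, k], scale=sigma[rows, k])


# Usage with hypothetical parameters: one input whose mixture puts weight
# 0.25 on a narrow component at y = -5 and 0.75 on one at y = 5.
rng = np.random.default_rng(3)
mu_demo = np.array([[-5.0, 5.0]])
sigma_demo = np.array([[0.1, 0.1]])
frac_demo = np.array([[0.25, 0.75]])
draws = sample_mixture(mu_demo, sigma_demo, frac_demo, 10000, rng)
print(draws.shape, np.mean(draws > 0))
```

In the notebook, `sample_mixture(test_mu, test_sigma, test_frac, samples_per_x, np.random.default_rng())` would replace the double loop, returning one row of sampled y values per entry of `x_test`.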