Mixture Density Networks: Basics
Background
I got interested in Mixture Density Networks while reading Bishop's book on machine learning. His original paper can be found here.
They are useful in problems where a single input can map to multiple output values, which is exactly where traditional discriminative neural networks fail.
Theory
The basic idea is that instead of using neural networks to learn a direct mapping from input variables, $\textbf{x}$, to target variables, $\textbf{t}$, one would use them to learn the parameters of a predefined distribution. In the scenario given in the book, the distribution is a mixture of Gaussians. Concretely, it postulates that the distribution of a particular target sample $\textbf{t}_i$, $p(\textbf{t}_i \mid \textbf{x}_i)$, is given by
$$ p(\textbf{t}_i \mid \textbf{x}_i) = \sum^K_{k=1}\pi_{ik}\, \mathrm{N}(\textbf{t}_i \mid \mathbf{\mu}_{ik}(\textbf{x}),\sigma_{ik}^2(\textbf{x})),$$
where $\mathbf{\mu}_{ik}$ and $\sigma_{ik}^2$ are the means and variances of the Gaussian mixture, respectively, and are functions of $\textbf{x}$.
Note that implicit in the problem setup is that the component Gaussians are spherical. This might not pose much of a problem if $\textbf{t}_i$ were a scalar (i.e. the Gaussians are univariate), but one might want to relax this assumption if $\textbf{t}_i$ were a vector consisting of non-independent variables.
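To make the mixture density concrete, here is a minimal NumPy/SciPy sketch that evaluates a univariate Gaussian mixture; the component parameters are made up purely for illustration:

```python
import numpy as np
from scipy import stats

# Hypothetical parameters for a K = 3 mixture (made up for illustration)
pis = np.array([0.5, 0.3, 0.2])     # mixing fractions, sum to one
mus = np.array([-2.0, 0.0, 3.0])    # component means
sigmas = np.array([0.5, 1.0, 0.8])  # component standard deviations

def mixture_pdf(t):
    # p(t) = sum_k pi_k * N(t | mu_k, sigma_k^2)
    return np.sum(pis * stats.norm.pdf(t, mus, sigmas))

# Sanity check: the density integrates to one over t
ts = np.linspace(-10.0, 10.0, 2001)
area = np.trapz([mixture_pdf(t) for t in ts], ts)
print(round(area, 3))  # 1.0
```

A neural network would replace the fixed `pis`, `mus` and `sigmas` above with functions of $\textbf{x}$.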
The likelihood of the dataset, $\mathcal{D}$, is thus given by
$$p(\mathcal{D}) = \prod_i p(\textbf{t}_i \mid \textbf{x}_i).$$
We can then go ahead as per usual to maximise the dataset log-likelihood. To state it explicitly, the negative log-likelihood (i.e. the error function that we minimise) is given by
$$\mathcal{E} = -\frac{1}{N}\sum_{i=1}^N\log\sum^K_{k=1}\pi_{ik}\, \mathrm{N}(\textbf{t}_i \mid \mathbf{\mu}_{ik}(\textbf{x}),\sigma_{ik}^2(\textbf{x})).$$
Key appeal factors
It is important to note that the means and variances of the mixture components are functions of $\textbf{x}$, which means that a neural network can be used to find their values. Indeed, this was the use case intended by Bishop. This is appealing to me for three reasons:

- as someone who has experience with neural networks, I can start applying them instead of dealing with the traditional method of fitting mixture models with the EM algorithm;

- it opens up the opportunity to apply networks such as convolutional neural networks to problems where translation invariance of features is important;

- it learns the underlying distribution, which means we can sample from it.
Implementation
One would build a neural network that maps the inputs to the parameters of the Gaussian mixture. The network would have $(L+2)K$ outputs, where $L$ is the dimensionality of the target variable $\textbf{t}$, mapping $\textbf{x}$ to each of the parameters required for the $K$ Gaussians: $L\times K$ outputs for the means, $K$ outputs for the mixing fractions and $K$ outputs for the standard deviations.
The values of $\mathbf{\mu}_{ik}$ would simply be the activations of the output layer. Just to be clear, the activations, $\textbf{a}$, refer to the weighted sum of the outputs of the previous layer, $\textbf{W}^T\textbf{z}+\textbf{b}$, where $\textbf{z}$ is the previous layer output and $\textbf{W}$ and $\textbf{b}$ are the weights and biases for the current layer, respectively.
The values of $\pi_{ik}$ must satisfy $\sum_k\pi_{ik}=1$, so a softmax function is applied to the relevant activations.
The values of $\sigma_{ik}$ must be positive, so an exponential function is applied to the relevant activations.
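These three output transformations can be sketched in NumPy; the raw activations below are random stand-ins for a real network's output layer, and the split into three equal groups assumes a scalar target ($L=1$):

```python
import numpy as np

K = 4  # number of mixture components (illustrative)
rng = np.random.default_rng(0)
activations = rng.normal(size=3*K)  # stand-in for the output-layer activations

mu_act, sigma_act, pi_act = np.split(activations, 3)

mu = mu_act                                 # means: raw activations, unconstrained
sigma = np.exp(sigma_act)                   # exponential keeps standard deviations positive
pi = np.exp(pi_act) / np.exp(pi_act).sum()  # softmax makes mixing fractions sum to one

print(np.all(sigma > 0), np.isclose(pi.sum(), 1.0))  # True True
```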
For more details, I refer you to the actual paper.
Beginning with a toy problem
Credit where credit's due
As I searched online for material on mixture density networks, I was surprised to find that many tutorials have already been published. This post is heavily inspired/guided/dependent on material published in:
- a blog post on blog.otoro.net
- a blog post on edwardlib.org
- a blog post by Christopher Bonnett
My aim here is to demonstrate and learn mixture density networks, and then subsequently move on to some ideas I have in mind for applying such networks. Hence you will find quite a bit of material here that replicates the above references.
Problem description
I tried to replicate the results in blog.otoro.net and edwardlib.org, i.e. to learn the underlying distribution that would give the data points shown below. As can be seen, each $x$ value maps to multiple $y$ values.
Import Libraries
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from scipy import stats
%matplotlib inline
plt.rcParams["figure.figsize"] = (12,8)
sess = tf.InteractiveSession()
Generate Samples
The samples are generated according to the equation
$$ x = 7\sin(0.75y) + 0.5y + \epsilon, $$
where $\epsilon \sim \mathrm{N}(0,1)$.
N_samples = 1000
y = np.random.uniform(-10.5, 10.5, N_samples)
x = (7*np.sin(0.75*y) + 0.5*y + np.random.normal(size=(N_samples)))
plt.plot(x, y, ".", alpha=0.5)
Define network
One cool thing to note here is the use of tf.reduce_logsumexp, which implements the log-sum-exp trick and prevents numerical instabilities.
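As a standalone illustration of the trick, assuming nothing beyond NumPy and SciPy: with very negative log-probabilities the naive computation underflows to minus infinity, while the shifted version survives.

```python
import numpy as np
from scipy.special import logsumexp

# Log-probabilities so small that exp() underflows to zero in float64
log_p = np.array([-1000.0, -1001.0, -1002.0])

with np.errstate(divide="ignore"):
    naive = np.log(np.sum(np.exp(log_p)))  # exp underflows, giving log(0) = -inf
stable = logsumexp(log_p)                  # subtracts the max before exponentiating

print(naive)             # -inf
print(round(stable, 3))  # -999.592
```

tf.reduce_logsumexp applies the same max-shifting internally, which is why the loss below stays finite even when individual component densities underflow.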
x_ = tf.placeholder(tf.float32, [None, 1])
y_ = tf.placeholder(tf.float32, [None, 1])

hidden_size_1 = 24
n_components = 24
lam = 0.01

output_1 = tf.layers.dense(x_,
                           hidden_size_1,
                           activation=tf.sigmoid,
                           kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False),
                           bias_initializer=tf.contrib.layers.xavier_initializer(uniform=False),
                           kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=lam))
final_output = tf.layers.dense(output_1,
                               3*n_components,
                               kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False),
                               bias_initializer=tf.contrib.layers.xavier_initializer(uniform=False),
                               kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=lam))

# Split the outputs into means, standard deviations and mixing fractions
means, sigma_act, fracs_act = tf.split(final_output, 3, axis=1)
sigma = tf.exp(sigma_act)         # exponential keeps the standard deviations positive
fracs = tf.nn.softmax(fracs_act)  # softmax makes the mixing fractions sum to one

diff = y_ - means
squared_dist = tf.square(diff/sigma)/2
# Negative log-likelihood, computed stably with the log-sum-exp trick
loss = tf.reduce_mean(-tf.reduce_logsumexp(
    tf.log(fracs) - squared_dist - 0.5*tf.log(np.pi*2*tf.square(sigma)), axis=1))
train_op = tf.train.AdamOptimizer().minimize(loss)
sess.run(tf.global_variables_initializer())
Train network
We can see that the training error basically saturates.
N_steps = 10000
training_loss = []
for i in range(N_steps):
    sess.run([train_op], feed_dict={x_: x.reshape((-1, 1)), y_: y.reshape((-1, 1))})
    if i % 500 == 0:
        print("Training batch: {0}".format(i))
        training_loss.append(sess.run(loss, feed_dict={x_: x.reshape((-1, 1)), y_: y.reshape((-1, 1))}))
plt.plot(500*np.arange(len(training_loss)),training_loss)
x_test = np.linspace(-15, 15, 100)
test_mu, test_sigma, test_frac = sess.run([means, sigma, fracs], feed_dict={x_: x_test.reshape((-1, 1))})
Creating distribution heatmap
The heat map shows that regions where data points exist have higher probability density.
dist_heatmap = np.zeros((100, 100))
y_test = np.linspace(-10, 10, 100)
for i in range(100):
    current_mu = test_mu[i, :]
    current_sigma = test_sigma[i, :]
    current_frac = test_frac[i, :]
    current_dist = np.zeros((1, 100))
    for j, (mu_, sigma_, frac_) in enumerate(zip(current_mu, current_sigma, current_frac)):
        temp = stats.norm.pdf(y_test, mu_, sigma_)*frac_
        current_dist = current_dist + temp
    # Flip so that large y values appear at the top of the image
    dist_heatmap[:, i] = np.fliplr(current_dist.reshape((1, -1)))
f, ax = plt.subplots(figsize=(12, 12))
ax.imshow(dist_heatmap)
ax.set_xticks([0, 99])
ax.set_xticklabels([-15, 15])
ax.set_yticks([0, 99])
ax.set_yticklabels([10, -10])
Sample data points
Sampling actual data points from the learnt distribution again shows that the sampled points fall within the high-probability regions shown above.
samples_per_x = 20
samples = np.zeros((len(x_test)*samples_per_x,2))
for i, (mu_, sigma_, frac_) in enumerate(zip(test_mu, test_sigma, test_frac)):
    for j in range(samples_per_x):
        # Pick a mixture component, then sample from that component's Gaussian
        mixture_idx = np.random.choice(n_components, p=frac_)
        samples[i*samples_per_x + j, 0] = x_test[i]
        samples[i*samples_per_x + j, 1] = np.random.normal(loc=mu_[mixture_idx], scale=sigma_[mixture_idx])
plt.plot(samples[:,0], samples[:,1], ".")