Beyond MNIST


Tensorflow, Convnets, Pipelines

Not Another MNIST example

MNIST might be regarded as the "Hello, World!" of neural nets. As such, we hurry past it in the quest to solve the problems we really care about. However, its value in practice may lie less in impressing anyone and more in letting us factor out data-related issues as we debug pipelines. This is something I've come to appreciate lately as I focus more on vision problems.

Deep learning platforms make it fairly straightforward to train neural nets efficiently. The Tensorflow tutorials guide you through the mechanics of setting up your first multi-layer perceptron for handwritten digit classification and running it on your GPU. But the library has much more to offer, so I decided to explore the API, progressively building up more functionality and generality from this basis. Throughout the implementation of what follows, the MNIST performance benchmarks were a reliable beacon for debugging.

The first thing we get to take for granted in using MNIST: no problems getting and preprocessing the data, and no problems with storage or memory. In fact, even a modest laptop can load it on the fly with two lines from the tutorial.

Playing with the resulting object using IPython's tab completion, we find several useful attributes and methods for setting up a train/validation/test split and loading the data into convenient arrays. The tutorial offers a crash course in the linear algebra needed to implement a multi-layer perceptron using the lower-level API. While this overview is instructive, it becomes tedious and redundant to spell out the matrix equations in deeper neural net architectures. This point resonates in the follow-up tutorial, where convolutional layers are introduced.

This "for experts" tutorial also requires some care in specifying dimensions, as the convolution and pooling layers are generally used to funnel the spatial dimensions of input images into deeper representations with many feature extractors. Again, instructive but tedious and error-prone. Fortunately, TF provides a 'layers' API that obviates the need to create your own helper functions to get shapes.

You get a lot of mileage with the introduction of convolutional layers. It is from this architecture that I will push the limits of performance by implementing sensible defaults as I drive the code to generalization beyond MNIST.

More Logging

The first step is to introduce additional instrumentation to view training progress. Tensorboard uses 'summary operations' specified during graph construction to serialize statistics of interest and FileWriters to write this to disk.

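# Your ops defined above.
tf.summary.scalar('loss', xentropy)
tf.summary.scalar('accuracy', accuracy)
# Other scalar summaries...
summary = tf.summary.merge_all()
# Start a session
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    # Write to the same directory you later pass to tensorboard
    summary_writer = tf.summary.FileWriter(
        '/path/to/writer/', sess.graph)
    # training & periodic evaluation
    # (feed and step stand in for your own feed_dict and step counter)
    sum_res = sess.run(summary, feed_dict=feed)
    summary_writer.add_summary(sum_res, step)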

After training begins, you can start up tensorboard on the (non-default) port 9009 with:

tensorboard --logdir=/path/to/writer/ --port=9009

This makes it easy to track training progress. You can evaluate loss or accuracy for training batches as well as your validation set, writing each to its own subfolder. Tensorboard will plot these performance statistics over training steps for both sets so that you can diagnose underfitting or overfitting.

Next, I want to save the model. This is accomplished with a saver.

# Your ops defined above.
saver = tf.train.Saver()
# Start a session
with tf.Session() as sess:
    # training & periodic evaluation,
    # then write a checkpoint (ckpt_file is a path of your choosing)
    saver.save(sess, ckpt_file)

Visualizing Learned Features

In a similar vein, I'd like to visualize activations under sample data to gain intuition for the learned features. I found this post helpful and have adapted the notebook example for a class method:

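import numpy as np
import matplotlib.pyplot as plt

def plot_activations(self, sess, layer_name, inpt):
    # Assumptions: layer_name is a full tensor name in the graph and
    # self.x is the input placeholder; adapt both to your class.
    layer = sess.graph.get_tensor_by_name(layer_name)
    units = sess.run(layer, feed_dict={
        self.x: np.reshape(inpt, [1, self.inpt_dim]),
        self.keep_prob: 1.0,
        self.is_training: False})
    n_filters = units.shape[3]
    plt.figure(1, figsize=(30, 30))
    n_cols = 6
    n_rows = int(np.ceil(n_filters / n_cols)) + 1
    for ii in range(n_filters):
        plt.subplot(n_rows, n_cols, ii + 1)
        plt.title('Filter ' + str(ii))
        # units has shape [1, height, width, n_filters]
        plt.imshow(units[0, :, :, ii],
                   interpolation='nearest', cmap='gray')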

Organizing Your Search

I also expect to make many runs, so tracking parameter configuration vs. performance can get unwieldy. To remedy this, I simply append a summary of key parameters and performance statistics to a log at the end of the training run.
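
A minimal sketch of such a run log, assuming a CSV file and the same set of keys on every run. The function name `append_run_summary` and the parameter names in the usage note (`learning_rate`, `keep_prob`, `val_accuracy`) are illustrative and not taken from the original code.

```python
import csv
import os
import time


def append_run_summary(log_path, params, metrics):
    """Append one row of parameters and metrics to a CSV run log."""
    row = {'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')}
    row.update(params)
    row.update(metrics)
    # Write the header only when the log is new or still empty.
    write_header = (not os.path.exists(log_path)
                    or os.path.getsize(log_path) == 0)
    with open(log_path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(row.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(row)
    return row
```

At the end of a training run this might be called as `append_run_summary('runs.csv', {'learning_rate': 1e-3, 'keep_prob': 0.5}, {'val_accuracy': val_acc})`, after which the log can be sorted or filtered in a spreadsheet or with pandas.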

Now that we can monitor important training statistics and save/visualize our models, I want to focus on generally useful techniques to improve performance. There are various activation functions that can be used, which may lead to faster convergence. We need only experiment with the options, though ReLU might be regarded as a good default.

Better Initializations

Different parameter initialization schemes have been found empirically to improve convergence by resisting the tendency toward exploding/vanishing gradients. To pass from the tutorial's truncated_normal initialization strategy to another, we import the variance_scaling_initializer from tf.contrib.layers (as he_init for concision):

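from tensorflow.contrib.layers import variance_scaling_initializer
he_init = variance_scaling_initializer()
conv_layer = tf.layers.conv2d(..., kernel_initializer=he_init)
fc_layer = tf.layers.dense(..., kernel_initializer=he_init)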
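
To make the idea behind variance scaling concrete, here is a small sketch in plain Python rather than Tensorflow. It assumes the He scheme, where weights are drawn from a normal distribution with standard deviation sqrt(2 / fan_in). The helper name `he_normal` is made up for illustration, and the library initializer differs in details (for example, it samples from a truncated normal).

```python
import math
import random


def he_normal(fan_in, fan_out, seed=None):
    """Return a fan_in x fan_out weight matrix with std sqrt(2 / fan_in)."""
    rng = random.Random(seed)
    std = math.sqrt(2.0 / fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_out)]
            for _ in range(fan_in)]


# A 784 -> 256 dense layer, as one might use on flattened MNIST images.
weights = he_normal(fan_in=784, fan_out=256, seed=0)
flat = [w for row in weights for w in row]
mean = sum(flat) / len(flat)
empirical_std = math.sqrt(sum((w - mean) ** 2 for w in flat) / len(flat))
target_std = math.sqrt(2.0 / 784)
```

The empirical standard deviation should land close to `target_std`. Scaling by the number of inputs is what keeps the variance of a layer's output roughly independent of its width.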

Control gradients with Batch Normalization

Initialization strategies like the one above are helpful early in training. However, as training progresses, the challenge of exploding/vanishing gradients can resurface. The technique of batch normalization can be employed to mitigate this risk, though this comes at the expense of increasing the computational overhead. To implement batch normalization in tensorflow, we have:

# While declaring placeholders:
is_training = tf.placeholder(tf.bool)
fc_layer = tf.layers.dense(...)
fc_layer = tf.layers.batch_normalization(
           fc_layer, training=is_training)
# Next, apply activations
# Note there is some dispute
# whether to apply batch_norm
# after the activation
# Before starting a session...
extra_update_ops = tf.get_collection(
    tf.GraphKeys.UPDATE_OPS)
with tf.Session() as sess:
    # train steps: run the update ops alongside train_op
    # (feed stands in for your feed_dict, with is_training: True)
    sess.run([train_op, extra_update_ops],
             feed_dict=feed)
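
What the layer computes in training mode can be sketched in plain Python. This is a simplified illustration for a single feature, not the Tensorflow implementation. `gamma`, `beta`, and `eps` are the usual scale, shift, and stability constant, and the helper names are made up. The second function shows the kind of moving-average bookkeeping that the extra update ops perform so that inference can use population statistics instead of batch statistics.

```python
import math


def batch_norm_train(values, gamma=1.0, beta=0.0, eps=1e-3):
    """Normalize one feature across a batch, then scale and shift."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # biased batch variance
    normalized = [gamma * (v - mean) / math.sqrt(var + eps) + beta
                  for v in values]
    return normalized, mean, var


def update_moving_average(moving, batch_stat, momentum=0.99):
    """Blend a batch statistic into the running estimate used at inference."""
    return momentum * moving + (1.0 - momentum) * batch_stat


batch = [0.5, 2.0, -1.5, 3.0]  # made-up activations for one unit
normalized, batch_mean, batch_var = batch_norm_train(batch)
moving_mean = update_moving_average(0.0, batch_mean)
moving_var = update_moving_average(1.0, batch_var)
```

If the update ops are never run, the moving averages stay at their initial values and the model can look fine in training yet perform poorly with `is_training` set to False.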

Control overfitting with Regularization

So far, we have instrumentation with summaries, visualization with tensorboard and custom methods, model persistence with saver, and we've passed from the low-level graph construction to use of layers where we could take advantage of alternative initialization schemes, activation functions, and implement batch normalization. Next, I want to consider regularization. There are many techniques designed to improve the ability of a deep learning model to generalize well to data not used in training. The tutorial's use of dropout layers as well as extensive parameter sharing through the use of convolutional layers takes us a long way. I wanted to also use early stopping. I found this example useful and have something like:

best_val_loss = float('inf')
patience = 2
patience_cnt = 0
min_delta = 0.005
# Start training loop...
if val_loss < best_val_loss:
    # New best: checkpoint the model and reset the counter
    saver.save(sess, ckpt_file)
    best_val_loss = val_loss
    patience_cnt = 0
elif (val_loss - best_val_loss) > min_delta:
    patience_cnt += 1
# at the level of epoch iteration:
if patience_cnt > patience:
    break  # stop training early
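
The same bookkeeping can be wrapped in a small helper and exercised without a model. The sketch below is one way to package it, with a made-up `EarlyStopper` class and a hand-written list of validation losses standing in for real evaluation. It mirrors the conditions above, where only a rise of more than `min_delta` over the best loss counts against patience.

```python
class EarlyStopper:
    """Track validation loss and signal when training should stop."""

    def __init__(self, patience=2, min_delta=0.005):
        self.patience = patience
        self.min_delta = min_delta
        self.best_val_loss = float('inf')
        self.patience_cnt = 0

    def update(self, val_loss):
        """Record a validation loss; return True if it is a new best."""
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.patience_cnt = 0
            return True  # the caller would save a checkpoint here
        if (val_loss - self.best_val_loss) > self.min_delta:
            self.patience_cnt += 1
        return False

    def should_stop(self):
        return self.patience_cnt > self.patience


stopper = EarlyStopper(patience=2, min_delta=0.005)
val_losses = [0.90, 0.60, 0.45, 0.47, 0.48, 0.50, 0.52]  # made-up values
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    stopper.update(val_loss)
    if stopper.should_stop():
        stopped_at = epoch
        break
```

With these values the loss bottoms out at 0.45 in epoch 2, the next three epochs each exceed it by more than `min_delta`, and the loop exits at epoch 5 without ever looking at the last entry.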

Exploit Symmetries with Data Augmentation

With this test, I can limit the risk of overfitting a model to the training data by terminating training steps when the validation loss doesn't continue to reduce appreciably. The min_delta and patience parameters will depend on the problem at hand and can be chosen after a few runs. Another important regularization strategy is called data augmentation. It is often desirable in image classification that small translations or rotations should not affect label assignment. This geometric intuition is part of the power in using convolutional and pooling layers but it can be assisted by generating appropriate perturbations of the training data to help in generalization. Here, I create a helper function:

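import numpy as np
from numpy.random import randint, permutation
from scipy import ndimage

def data_augmenter(xs, height, width):
    # Median pixel value, assumed here to be the fill for pixels
    # exposed by a rotation or shift
    med_val = np.median(xs)
    x_image = np.reshape(xs, (height, width))
    # Random order of: 0 = stop, 1 = rotate, 2 = translate
    trans_op = permutation([0, 1, 2])
    for op in trans_op:
        if op == 0:
            return x_image.flatten()
        elif op == 1:
            angle = randint(-10, 10)  # degrees
            x_image = ndimage.rotate(
                x_image, angle, reshape=False,
                cval=med_val)
        else:
            shift = randint(-2, 2, 2)  # (rows, cols) in pixels
            x_image = ndimage.shift(
                x_image, shift, cval=med_val)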

This function applies a sequence of transformations that may include the identity, a small rotation, and a small translation. Calling this function on samples during training time limits the risk of overfitting through memorizing the training data.

At this point, we have introduced a number of reputed techniques for improving model convergence and monitoring progress. In an effort to make this code more generally useful, I want to move away from a top-level script and introduce a class to construct the graph similar to this. I will also use argparse to pass command-line arguments to make the code more flexible.
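
As a rough sketch of the command-line side, assuming hypothetical flag names (the flags of the actual script are not shown in this post):

```python
import argparse


def build_parser():
    """Build a parser for a few typical training options."""
    parser = argparse.ArgumentParser(description='Train a convnet.')
    parser.add_argument('--learning-rate', type=float, default=1e-3)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--log-dir', default='/path/to/writer/')
    parser.add_argument('--augment', action='store_true',
                        help='apply data augmentation to training batches')
    return parser


# An explicit list is parsed here for illustration; a script would call
# parse_args() with no arguments to read sys.argv instead.
args = build_parser().parse_args(['--batch-size', '128', '--augment'])
```

The parsed values arrive as attributes such as `args.batch_size` and `args.learning_rate`, which can be passed straight to the constructor of the graph-building class.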

I also wrote a class to read images from a specific directory structure and provide attributes and methods similar to the Dataset class used in the MNIST tutorial. Finally, because a general image corpus will not uniformly contain 28x28 or 299x299 images, I wrote an image preprocessing function that resizes while maintaining the aspect ratio, then returns a number of images cropped to the required dimensions by translating along the longer side. The full code for this post can be found here.
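
The geometry of that preprocessing step can be sketched without any image library. The function below is an illustration under assumptions, not the code from the repository: it assumes square targets, scales the shorter side to the target size, and spaces the crops evenly along the longer side. The name `crop_boxes` and the `n_crops` parameter are made up.

```python
def crop_boxes(width, height, target, n_crops=3):
    """Return the resized (width, height) and a list of crop boxes.

    The shorter side is scaled to `target`, preserving the aspect ratio.
    Square crops of side `target` are then spaced evenly along the longer
    side. Boxes are (left, upper, right, lower) in resized-image pixels.
    """
    scale = target / min(width, height)
    new_w = max(target, round(width * scale))
    new_h = max(target, round(height * scale))
    slack = max(new_w, new_h) - target  # room to slide along the long side
    if slack == 0 or n_crops == 1:
        offsets = [slack // 2]  # a single centered crop
    else:
        offsets = [round(i * slack / (n_crops - 1)) for i in range(n_crops)]
    boxes = []
    for off in offsets:
        if new_w >= new_h:
            boxes.append((off, 0, off + target, target))
        else:
            boxes.append((0, off, target, off + target))
    return (new_w, new_h), boxes


size, boxes = crop_boxes(600, 300, target=100, n_crops=3)
```

For a 600x300 image and a 100-pixel target, the image is resized to 200x100 and the three boxes start at x = 0, 50, and 100. The box order matches the convention of Pillow's `Image.crop`, so the tuples could be passed to it directly after resizing.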

Taking the Module Further

There are many other directions we can take. For instance, we could introduce a config file, distribute the computation across available devices, write an evaluation script, read from TFRecords, or offer additional flexibility in choosing activation functions, optimizers, and initialization strategies. TF-Slim is a great source of inspiration in this regard. I hope this overview has helped "experts" bridge the gap toward more flexible convnets.