Building an Image Classifier in Python

A step-by-step implementation of a Dogs vs Cats classifier in Keras

Much of the technology around you today is powered by Machine Learning and Deep Learning. From the way Google Photos picks your face out of a crowd, to autonomous vehicles, these algorithms have changed how machines perceive the world.

Among the most recent (and exciting) areas of progress is Computer Vision, where insights are derived from media such as images and video. One of its most prominent tasks is Image Classification: the ability of a machine to assign a label, or class, to a given image.

In this article, we’ll be building an image classifier that distinguishes cats from dogs with sizable accuracy.

Picking up the dataset

We’ll be using a filtered version of the Dogs vs Cats dataset from Kaggle. You can download it from here.

We’ll be using TensorFlow’s high-level API, Keras, for this lesson. So, go ahead and import the necessary packages.
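A minimal set of imports for this walkthrough might look like the following (assuming TensorFlow 2.x, where Keras ships as `tf.keras`):

```python
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
```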

We then download the filtered dataset and unzip it to a folder named `cats_and_dogs_filtered`.

The dataset has the following structure:

|__ train
|______ cats: [cat.0.jpg, cat.1.jpg, cat.2.jpg ....]
|______ dogs: [dog.0.jpg, dog.1.jpg, dog.2.jpg ...]
|__ validation
|______ cats: [cat.2000.jpg, cat.2001.jpg, cat.2002.jpg ....]
|______ dogs: [dog.2000.jpg, dog.2001.jpg, dog.2002.jpg ...]

Next, we’re going to assign variables for the training and validation sets:
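A sketch of those assignments, assuming `PATH` points at the root of the unzipped `cats_and_dogs_filtered` folder:

```python
import os

PATH = "cats_and_dogs_filtered"  # assumed root of the unzipped dataset

train_dir = os.path.join(PATH, "train")
validation_dir = os.path.join(PATH, "validation")

train_cats_dir = os.path.join(train_dir, "cats")
train_dogs_dir = os.path.join(train_dir, "dogs")
validation_cats_dir = os.path.join(validation_dir, "cats")
validation_dogs_dir = os.path.join(validation_dir, "dogs")
```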

To see how many images each class contains in the training and validation sets, go ahead and use the following code:
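One way to do the counting (a sketch; `train_dir` and `validation_dir` are the path variables assigned above):

```python
import os

def count_images(directory):
    # Number of files directly inside `directory`
    return len(os.listdir(directory))

# Assuming the directory variables from the previous step:
# num_cats_tr = count_images(os.path.join(train_dir, "cats"))
# num_dogs_tr = count_images(os.path.join(train_dir, "dogs"))
# total_train = num_cats_tr + num_dogs_tr
# print("Total training images:", total_train)
```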

Preprocessing the Data

The images will have to be appropriately preprocessed before we feed them into our model.

We’re going to instantiate the ImageDataGenerator class to generate batches of images that will be fed into the model. We rescale the pixel values to the range [0, 1] to help the network converge faster.
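A sketch of that step, wrapped in a small helper; the image size and batch size below are my assumptions, not values fixed by the article:

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMG_HEIGHT, IMG_WIDTH = 150, 150  # assumed input resolution
batch_size = 128                  # assumed batch size

def make_train_generator(directory):
    # Rescale pixels from [0, 255] to [0, 1] and stream shuffled batches
    gen = ImageDataGenerator(rescale=1./255)
    return gen.flow_from_directory(
        directory=directory,
        batch_size=batch_size,
        shuffle=True,
        target_size=(IMG_HEIGHT, IMG_WIDTH),
        class_mode="binary",
    )

# train_data_gen = make_train_generator(train_dir)
```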

We will now extract a batch of images from the train_data_gen generator to visualize our training set, and then plot 5 of them using matplotlib.pyplot.
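A sketch of the visualization step; `train_data_gen` is assumed to be the generator built above:

```python
import matplotlib.pyplot as plt

def plot_images(images):
    # Plot the first 5 images in a single row
    fig, axes = plt.subplots(1, 5, figsize=(20, 4))
    for img, ax in zip(images, axes):
        ax.imshow(img)
        ax.axis("off")
    fig.tight_layout()
    return fig

# sample_training_images, _ = next(train_data_gen)
# plot_images(sample_training_images[:5])
```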

5 images from the training set, plotted with Matplotlib

Defining the Model Architecture

Compiling the Model

We’ll use the Adam optimizer and binary cross-entropy as the loss function.
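A plausible architecture for these two steps: three Conv/MaxPool blocks followed by a dense head. The exact filter counts are my assumption (the original layer sizes were in an embedded code snippet that did not survive). The final layer emits a single logit, so the loss is binary cross-entropy with `from_logits=True`:

```python
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

model = Sequential([
    Conv2D(16, 3, padding="same", activation="relu",
           input_shape=(150, 150, 3)),
    MaxPooling2D(),
    Conv2D(32, 3, padding="same", activation="relu"),
    MaxPooling2D(),
    Conv2D(64, 3, padding="same", activation="relu"),
    MaxPooling2D(),
    Flatten(),
    Dense(512, activation="relu"),
    Dense(1),  # single logit: cat vs dog
])

model.compile(
    optimizer="adam",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.summary()  # prints the layer-by-layer architecture
```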

Our model’s architecture so far

Training the model

Here’s where the magic happens! We’re going to actually train the model by feeding it our input images. For this, we’ll use the model’s `fit` method.

We will allow the model to train for 15 epochs. This is relatively few, but remember, we have a small dataset (~3,000 images).
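The training call might look like the following sketch; the generator and count variables come from the earlier steps and are assumptions on my part (older Keras versions used `model.fit_generator` for the same job):

```python
EPOCHS = 15  # as stated in the article

def train(model, train_gen, val_gen, total_train, total_val,
          batch_size, epochs=EPOCHS):
    # Fit the model on the training generator, validating after each epoch
    return model.fit(
        train_gen,
        steps_per_epoch=total_train // batch_size,
        epochs=epochs,
        validation_data=val_gen,
        validation_steps=total_val // batch_size,
    )

# history = train(model, train_data_gen, val_data_gen,
#                 total_train, total_val, batch_size)
```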

Training progress of the model

Visualizing the Training Results

We’re now going to plot a graph that will allow us to see the accuracy and loss as our model trained.
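A sketch of the plotting code, using the `history` object returned by training. In TF 2.x the metric keys are `accuracy`/`val_accuracy` (older Keras versions used `acc`/`val_acc`):

```python
import matplotlib.pyplot as plt

def plot_history(history, epochs):
    # Side-by-side plots: accuracy on the left, loss on the right
    acc = history.history["accuracy"]
    val_acc = history.history["val_accuracy"]
    loss = history.history["loss"]
    val_loss = history.history["val_loss"]
    epochs_range = range(epochs)

    fig = plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, acc, label="Training Accuracy")
    plt.plot(epochs_range, val_acc, label="Validation Accuracy")
    plt.legend(loc="lower right")
    plt.title("Training and Validation Accuracy")

    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label="Training Loss")
    plt.plot(epochs_range, val_loss, label="Validation Loss")
    plt.legend(loc="upper right")
    plt.title("Training and Validation Loss")
    return fig

# plot_history(history, EPOCHS)
# plt.show()
```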

Training and Validation Accuracy (left); Training and Validation Loss (right)

It’s quite clear that we’ve achieved a training accuracy of ~85% but a validation accuracy of only ~70%. Our model is clearly overfitting!

Preventing Overfitting

Overfitting usually happens when a model learns patterns specific to the training set, patterns that don’t generalize. As a result, it performs noticeably worse on the validation data.

To mitigate this issue, we will employ Data Augmentation and Dropout.

Data augmentation synthesizes new training samples from already existing data; in our case, this includes random flips, rotations, and zooms.

We can achieve this by use of the following code:
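A sketch of an augmenting generator; the exact transform ranges below are my assumptions:

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Same [0, 1] rescaling as before, plus random flips, rotations,
# shifts and zooms (ranges are assumptions)
image_gen_train = ImageDataGenerator(
    rescale=1./255,
    rotation_range=45,
    width_shift_range=0.15,
    height_shift_range=0.15,
    horizontal_flip=True,
    zoom_range=0.5,
)

# train_data_gen = image_gen_train.flow_from_directory(
#     batch_size=batch_size,
#     directory=train_dir,
#     shuffle=True,
#     target_size=(IMG_HEIGHT, IMG_WIDTH),
#     class_mode="binary",
# )
```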

Data augmentation applied to a random image from our training set

Similarly, we can create a validation data generator. Remember, we never apply augmentation to the validation set.
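The validation generator keeps only the rescaling, so validation metrics are measured on unmodified images (`validation_dir` and the size/batch constants are the assumed variables from earlier):

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Rescaling only; no random transforms on validation data
image_gen_val = ImageDataGenerator(rescale=1./255)

# val_data_gen = image_gen_val.flow_from_directory(
#     batch_size=batch_size,
#     directory=validation_dir,
#     target_size=(IMG_HEIGHT, IMG_WIDTH),
#     class_mode="binary",
# )
```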

In addition to data augmentation, we’ll be using Dropout, a regularization technique that randomly disables a fraction of the network’s units during training, discouraging the network from relying too heavily on any single feature.

We’ve added a 20% dropout twice in the model architecture:
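A sketch of that architecture, with the same assumed layer sizes as before and the two 20% dropout layers placed after the first and last pooling layers (the exact placement in the original was in the embedded snippet):

```python
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Conv2D, MaxPooling2D, Flatten,
                                     Dense, Dropout)

model_new = Sequential([
    Conv2D(16, 3, padding="same", activation="relu",
           input_shape=(150, 150, 3)),
    MaxPooling2D(),
    Dropout(0.2),       # first 20% dropout
    Conv2D(32, 3, padding="same", activation="relu"),
    MaxPooling2D(),
    Conv2D(64, 3, padding="same", activation="relu"),
    MaxPooling2D(),
    Dropout(0.2),       # second 20% dropout
    Flatten(),
    Dense(512, activation="relu"),
    Dense(1),
])

model_new.compile(
    optimizer="adam",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
```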

When you train this model, you’ll notice significantly less overfitting than before.

Plotting the graphs, we get:

We’ve obtained an accuracy of ~65%

That’s it! You’ve built an image classifier that distinguishes cats from dogs with an overall accuracy of 65%. It’s not the best, but given the limited size of the dataset and the model architecture we’ve used, it’s pretty good.

Thanks for reading! Let me know what you think in the comments below!

Written by

I train ConvNets. Currently building Caer, a Computer Vision library in Python.
