About

In this tutorial, I will first teach you how to build a recurrent neural network (RNN) with a single layer, consisting of a single neuron, using PyTorch and Google Colab. I will also show you how to implement a simple RNN-based model for image classification.

This work is heavily inspired by Aurélien Géron's book "Hands-On Machine Learning with Scikit-Learn and TensorFlow". Although his neural network implementations are purely in TensorFlow, I adopted some of his notation and variable names and implemented everything using PyTorch only. I really enjoyed his book and learned a lot from his explanations. His work inspired this tutorial and I strongly recommend the book.

We first import the necessary libraries we will use in the tutorial:

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np

RNN with a Single Neuron

The idea of this tutorial is to show you the basic operations necessary for building an RNN architecture using PyTorch. This guide assumes you have knowledge of basic RNNs and that you have read the tutorial on building neural networks from scratch using PyTorch. I will review RNN concepts wherever possible for those who need a refresher, but I will keep it minimal.

First, let's build the computation graph for a single-layer RNN. Again, we are not concerned with the math for now; I just want to show you the PyTorch operations needed to build your RNN models.

For illustration purposes, this is the architecture we are building:

[Figure: the single-layer, single-neuron RNN architecture]

And here is the code:

class SingleRNN(nn.Module):
    def __init__(self, n_inputs, n_neurons):
        super(SingleRNN, self).__init__()
        
        self.Wx = torch.randn(n_inputs, n_neurons) # 4 X 1
        self.Wy = torch.randn(n_neurons, n_neurons) # 1 X 1
        
        self.b = torch.zeros(1, n_neurons) # 1 X 1
        
    def forward(self, X0, X1):
        self.Y0 = torch.tanh(torch.mm(X0, self.Wx) + self.b) # 4 X 1
        
        self.Y1 = torch.tanh(torch.mm(self.Y0, self.Wy) +
                            torch.mm(X1, self.Wx) + self.b) # 4 X 1
        
        return self.Y0, self.Y1

In the above code, I have implemented a simple one-layer, one-neuron RNN. I initialized two weight matrices, Wx and Wy, with values from a normal distribution. Wx contains the connection weights for the inputs of the current time step, while Wy contains the connection weights for the outputs of the previous time step. We also added a bias b. The forward function computes one output per time step, two in this case. Note that we are using tanh as the nonlinearity (activation function).
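Written out with the same names used in the code, the forward pass simply computes:

Y0 = tanh(X0 · Wx + b)
Y1 = tanh(Y0 · Wy + X1 · Wx + b)

where X0 and X1 are the input batches at time steps 0 and 1, and · denotes matrix multiplication (torch.mm).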

As for the input, we are providing a batch of 4 instances. Each instance is a sequence of two time steps (X0 and X1), and each time step carries 4 input features.

For illustration purposes, this is how the data is being fed into the RNN model:

[Figure: how the input batches are fed into the RNN model]

And this is the code to test the model:

N_INPUT = 4
N_NEURONS = 1

X0_batch = torch.tensor([[0,1,2,0], [3,4,5,0], 
                         [6,7,8,0], [9,0,1,0]],
                        dtype = torch.float) #t=0 => 4 X 4

X1_batch = torch.tensor([[9,8,7,0], [0,0,0,0], 
                         [6,5,4,0], [3,2,1,0]],
                        dtype = torch.float) #t=1 => 4 X 4

model = SingleRNN(N_INPUT, N_NEURONS)

Y0_val, Y1_val = model(X0_batch, X1_batch)

After we have fed the input into the computation graph, we obtain outputs for each timestep (Y0, Y1), which we can now print as follows:

print(Y0_val)
print(Y1_val)
tensor([[-0.1643],
        [-0.9995],
        [-1.0000],
        [-1.0000]])
tensor([[-1.0000],
        [-0.6354],
        [-1.0000],
        [-0.9998]])

Increasing Neurons in RNN Layer

Next, I will show you how to generalize the RNN we have just built so that the single layer supports an arbitrary number of neurons. In terms of the architecture, nothing really changes since we have already parameterized the number of neurons in the computation graph we have built. However, the size of the output changes because we have changed the number of units (i.e., neurons) in the RNN layer.

Here is an illustration of what we will build:

[Figure: a single-layer RNN with multiple neurons]

And here is the code:

class BasicRNN(nn.Module):
    def __init__(self, n_inputs, n_neurons):
        super(BasicRNN, self).__init__()
        
        self.Wx = torch.randn(n_inputs, n_neurons) # n_inputs X n_neurons
        self.Wy = torch.randn(n_neurons, n_neurons) # n_neurons X n_neurons
        
        self.b = torch.zeros(1, n_neurons) # 1 X n_neurons
    
    def forward(self, X0, X1):
        self.Y0 = torch.tanh(torch.mm(X0, self.Wx) + self.b) # batch_size X n_neurons
        
        self.Y1 = torch.tanh(torch.mm(self.Y0, self.Wy) +
                            torch.mm(X1, self.Wx) + self.b) # batch_size X n_neurons
        
        return self.Y0, self.Y1

N_INPUT = 3 # number of features in input
N_NEURONS = 5 # number of units in layer

X0_batch = torch.tensor([[0,1,2], [3,4,5], 
                         [6,7,8], [9,0,1]],
                        dtype = torch.float) #t=0 => 4 X 3

X1_batch = torch.tensor([[9,8,7], [0,0,0], 
                         [6,5,4], [3,2,1]],
                        dtype = torch.float) #t=1 => 4 X 3

model = BasicRNN(N_INPUT, N_NEURONS)

Y0_val, Y1_val = model(X0_batch, X1_batch)

Now when we print the outputs produced for each time step, it is of size (4 X 5), which represents the batch size and number of neurons, respectively.

print(Y0_val)
print(Y1_val)
tensor([[ 0.9975, -0.9785,  0.9822, -0.8972,  0.9929],
        [ 0.9999, -0.9998,  1.0000, -0.9865,  0.9447],
        [ 1.0000, -1.0000,  1.0000, -0.9983,  0.6298],
        [-1.0000,  0.9915,  0.7409,  1.0000, -1.0000]])
tensor([[ 0.9858, -1.0000,  1.0000, -0.8826, -1.0000],
        [ 0.1480, -0.8635, -0.4498,  0.3516, -0.2848],
        [-0.1455, -0.9988,  1.0000,  0.0260, -0.9997],
        [-0.4084,  0.9973,  0.8858,  0.0783, -0.9993]])

PyTorch Built-in RNN Cell

If you take a closer look at the BasicRNN computation graph we have just built, it has a serious limitation: what if we wanted an architecture that supports sequences with many time steps? The way it is currently written, we would have to compute the output for every time step explicitly, adding more lines of code for each additional step. Below I will show you how to consolidate this and implement it more efficiently and cleanly using the built-in RNNCell module.

Let's first try to implement this informally to analyze the role RNNCell plays:

rnn = nn.RNNCell(3, 5) # n_input X n_neurons

X_batch = torch.tensor([[[0,1,2], [3,4,5], 
                         [6,7,8], [9,0,1]],
                        [[9,8,7], [0,0,0], 
                         [6,5,4], [3,2,1]]
                       ], dtype = torch.float) # X0 and X1

hx = torch.randn(4, 5) # batch_size X n_neurons (random initial hidden state)
output = []

# for each time step
for i in range(2):
    hx = rnn(X_batch[i], hx)
    output.append(hx)

print(output)
[tensor([[ 0.2545,  0.7355,  0.3708, -0.6381,  0.0402],
        [-0.3379,  0.9996,  0.9976, -0.9769,  0.6668],
        [-0.9940,  1.0000,  1.0000, -0.9992,  0.4488],
        [-0.7486,  0.9925,  0.9862, -0.9642,  0.9990]], grad_fn=<TanhBackward>), tensor([[-0.9848,  1.0000,  1.0000, -0.9999,  0.9970],
        [ 0.2496, -0.7512,  0.1730, -0.3533, -0.7347],
        [-0.9502,  0.9998,  0.9995, -0.9966,  0.9119],
        [-0.6488,  0.7944,  0.9580, -0.9171,  0.2384]], grad_fn=<TanhBackward>)]

With the above code, we have basically implemented the same model as BasicRNN. nn.RNNCell(...) does all the magic of creating and maintaining the necessary weights and biases for us. Given an input tensor and the current hidden state, nn.RNNCell outputs the next hidden state for each element in the batch. You can read more about this module in the PyTorch documentation.
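As a quick sanity check of the shapes involved, here is a minimal example using the rnn cell defined above (x_t, h_t, and h_next are just illustrative names; the exact values will differ since the weights are randomly initialized):

x_t = torch.randn(4, 3)   # one time step of input: batch_size X n_inputs
h_t = torch.randn(4, 5)   # current hidden state:   batch_size X n_neurons
h_next = rnn(x_t, h_t)    # the cell returns the next hidden state
print(h_next.shape)       # torch.Size([4, 5])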

Now, let's formally build the computation graph using the same information we used above.

class CleanBasicRNN(nn.Module):
    def __init__(self, batch_size, n_inputs, n_neurons):
        super(CleanBasicRNN, self).__init__()
        
        self.rnn = nn.RNNCell(n_inputs, n_neurons)
        self.hx = torch.randn(batch_size, n_neurons) # initialize hidden state
        
    def forward(self, X):
        output = []

        # for each time step
        for i in range(2):
            self.hx = self.rnn(X[i], self.hx)
            output.append(self.hx)
        
        return output, self.hx

FIXED_BATCH_SIZE = 4 # our batch size is fixed for now
N_INPUT = 3
N_NEURONS = 5

X_batch = torch.tensor([[[0,1,2], [3,4,5], 
                         [6,7,8], [9,0,1]],
                        [[9,8,7], [0,0,0], 
                         [6,5,4], [3,2,1]]
                       ], dtype = torch.float) # X0 and X1


model = CleanBasicRNN(FIXED_BATCH_SIZE, N_INPUT, N_NEURONS)
output_val, states_val = model(X_batch)
print(output_val) # contains the outputs (hidden states) for all time steps
print(states_val) # contains the hidden state for the final time step, i.e., t=1
[tensor([[ 0.4582, -0.9106,  0.0743, -0.9608,  0.9272],
        [ 0.2087, -0.9999, -0.9486, -0.9969,  0.9996],
        [-0.2371, -1.0000, -0.7662, -1.0000,  1.0000],
        [-0.9576, -0.9306, -0.1201, -0.9781,  0.9277]], grad_fn=<TanhBackward>), tensor([[-0.9237, -1.0000, -0.9743, -1.0000,  1.0000],
        [-0.3181, -0.6270, -0.6122,  0.1921,  0.0647],
        [-0.7835, -0.9991, -0.9098, -0.9999,  0.9976],
        [-0.5765, -0.8469, -0.5469, -0.9785,  0.7512]], grad_fn=<TanhBackward>)]
tensor([[-0.9237, -1.0000, -0.9743, -1.0000,  1.0000],
        [-0.3181, -0.6270, -0.6122,  0.1921,  0.0647],
        [-0.7835, -0.9991, -0.9098, -0.9999,  0.9976],
        [-0.5765, -0.8469, -0.5469, -0.9785,  0.7512]], grad_fn=<TanhBackward>)

You can see how the code is much cleaner since we don't need to explicitly operate on the weights as in the previous code snippet -- everything is handled implicitly and elegantly behind the scenes by PyTorch.
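If you are curious, the weights and biases that RNNCell maintains are still accessible as attributes of the cell; for example, for the CleanBasicRNN instance we just created:

print(model.rnn.weight_ih.shape) # torch.Size([5, 3]) -> n_neurons X n_inputs
print(model.rnn.weight_hh.shape) # torch.Size([5, 5]) -> n_neurons X n_neurons
print(model.rnn.bias_ih.shape)   # torch.Size([5])    -> one bias term per neuron

Because these are registered as model parameters, an optimizer can update them during training without any extra work on our side.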

RNN for Image Classification

[Figure: RNN for image classification]

Now that you have learned how to build a simple RNN from scratch and using the built-in RNNCell module provided in PyTorch, let's do something more sophisticated and special.

Let's try to build an image classifier using the MNIST dataset. The MNIST dataset consists of images of handwritten digits from 0 to 9. Essentially, we want to build a classifier that predicts the digit shown in each image. I know this sounds strange, but you will be surprised by how well RNNs perform on this image classification task.

In addition, we will use the RNN module instead of the RNNCell module since we want the computation graph to be able to support an arbitrary number of layers as well. We will only use one layer in the following computation graph, but you can experiment with the code later on by adding more layers.
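As a small sketch of that experiment (not used in the model below; two_layer_rnn and h0 are just illustrative names), stacking layers only requires passing num_layers to nn.RNN and sizing the initial hidden state accordingly:

two_layer_rnn = nn.RNN(input_size=28, hidden_size=150, num_layers=2)
# the initial hidden state must then have shape (num_layers, batch_size, hidden_size)
h0 = torch.zeros(2, 64, 150)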

Importing the dataset

Before building the RNN-based computation graph, let's import the MNIST dataset, split it into test and train portions, do a few transformations, and further explore it. You will need the following PyTorch libraries and lines of code to download and import the MNIST dataset to Google Colab.

import torchvision
import torchvision.transforms as transforms

%%capture
BATCH_SIZE = 64

# list all transformations
transform = transforms.Compose(
    [transforms.ToTensor()])

# download and load training dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

# download and load testing dataset
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

The code above loads and prepares the dataset to be fed into the computation graph we will build later on. Take a few minutes to play around with the code and understand what is happening. Notice that we needed to provide a batch size. This is because trainloader and testloader are iterable DataLoader objects, which makes it easy to loop over the dataset and train our RNN model with minibatches.
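As a quick check, you can pull out a single minibatch and inspect its shape; with the settings above you should see something like the following:

images, labels = next(iter(trainloader))
print(images.shape) # torch.Size([64, 1, 28, 28]) -> batch, channel, height, width
print(labels.shape) # torch.Size([64])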

Exploring the dataset

Here are a few lines of code to explore the dataset. I won't cover much of what's going on here, but you can take some time and look at it by yourself.

import matplotlib.pyplot as plt
import numpy as np

# functions to show an image
def imshow(img):
    #img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))

Model

Let's construct the computation graph. Below are the parameters:

# parameters 
N_STEPS = 28
N_INPUTS = 28
N_NEURONS = 150
N_OUTPUTS = 10
N_EPOCHS = 10

And finally, here is a figure of the RNN-based classification model we are building:

[Figure: RNN-based image classification model]

And here is the code for the model:

class ImageRNN(nn.Module):
    def __init__(self, batch_size, n_steps, n_inputs, n_neurons, n_outputs):
        super(ImageRNN, self).__init__()
        
        self.n_neurons = n_neurons
        self.batch_size = batch_size
        self.n_steps = n_steps
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        
        self.basic_rnn = nn.RNN(self.n_inputs, self.n_neurons) 
        
        self.FC = nn.Linear(self.n_neurons, self.n_outputs)
        
    def init_hidden(self):
        # (num_layers, batch_size, n_neurons)
        return (torch.zeros(1, self.batch_size, self.n_neurons))
        
    def forward(self, X):
        # transforms X to dimensions: n_steps X batch_size X n_inputs
        X = X.permute(1, 0, 2) 
        
        self.batch_size = X.size(1)
        self.hidden = self.init_hidden()
        
        # rnn_out => n_steps, batch_size, n_neurons (hidden states for every time step)
        # self.hidden => 1, batch_size, n_neurons (hidden state for the final time step only)
        rnn_out, self.hidden = self.basic_rnn(X, self.hidden)
        out = self.FC(self.hidden)
        
        return out.view(-1, self.n_outputs) # batch_size X n_outputs

The ImageRNN model is doing the following:

  • The initialization function __init__(...) declares a few variables and then defines a basic RNN layer, basic_rnn, followed by a fully-connected layer, self.FC.
  • The init_hidden function initializes the hidden state with zeros. The forward function accepts an input of size batch_size X n_steps X n_inputs, permutes it to n_steps X batch_size X n_inputs, and then passes the data through the RNN layer and the fully-connected layer.
  • The output is a tensor of raw class scores (logits), one row per instance in the batch (see the shape trace below).
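As a rough trace of the tensor shapes flowing through forward(), assuming a full batch of 64 MNIST images treated as 28 time steps of 28 features each:

# input X:          64 x 28 x 28   (batch_size, n_steps, n_inputs)
# after permute:    28 x 64 x 28   (n_steps, batch_size, n_inputs)
# rnn_out:          28 x 64 x 150  (hidden state at every time step)
# self.hidden:       1 x 64 x 150  (hidden state at the last time step only)
# self.FC output:    1 x 64 x 10   -> reshaped by view() to 64 x 10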

Testing the model with some samples

A very good practice encouraged by the PyTorch developers throughout their documentation, and one which I really like and highly recommend, is to always test the model with a portion of the dataset before the actual training. This ensures that you have specified the correct dimensions and that the model is outputting the information you expect. Below I show an example of how to test your model:

dataiter = iter(trainloader)
images, labels = next(dataiter)
model = ImageRNN(BATCH_SIZE, N_STEPS, N_INPUTS, N_NEURONS, N_OUTPUTS)
logits = model(images.view(-1, 28, 28))
print(logits[0:10])
tensor([[-0.0937, -0.0978, -0.0586,  0.0161,  0.0557,  0.0227, -0.0226, -0.0067,
          0.1092, -0.1295],
        [-0.0878, -0.0855, -0.0318,  0.0267,  0.0569,  0.0349, -0.0275,  0.0007,
          0.0999, -0.1215],
        [-0.0829, -0.1012, -0.0541,  0.0155,  0.0562,  0.0162, -0.0258, -0.0100,
          0.1077, -0.1310],
        [-0.1004, -0.0744, -0.0163,  0.0465,  0.0382,  0.0289, -0.0569,  0.0015,
          0.1003, -0.1266],
        [-0.0946, -0.0994, -0.0636,  0.0132,  0.0539,  0.0236, -0.0221, -0.0034,
          0.1013, -0.1298],
        [-0.0922, -0.0974, -0.0334,  0.0369,  0.0622,  0.0378, -0.0497,  0.0005,
          0.0983, -0.1160],
        [-0.0834, -0.0942, -0.0414,  0.0258,  0.0573,  0.0174, -0.0218, -0.0105,
          0.1045, -0.1307],
        [-0.0782, -0.0985, -0.0458,  0.0154,  0.0579,  0.0214, -0.0227, -0.0060,
          0.1035, -0.1269],
        [-0.1019, -0.0963, -0.0549,  0.0214,  0.0551,  0.0203, -0.0167, -0.0048,
          0.1131, -0.1316],
        [-0.1078, -0.1001, -0.0372,  0.0187,  0.0682,  0.0412, -0.0265, -0.0021,
          0.1033, -0.1191]], grad_fn=<SliceBackward>)

Training

Now let's look at the code for training the image classification model. But first, let's declare a few helper functions needed to train the model:

import torch.optim as optim

# Device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model instance
model = ImageRNN(BATCH_SIZE, N_STEPS, N_INPUTS, N_NEURONS, N_OUTPUTS)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def get_accuracy(logit, target, batch_size):
    ''' Obtain accuracy for training round '''
    corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
    accuracy = 100.0 * corrects/batch_size
    return accuracy.item()

Before training a model in PyTorch, you can programmatically specify which device you want to use during training; the torch.device(...) call tells the program that we want to use the GPU if one is available, otherwise the CPU will be the default device.

Then we create an instance of the model, ImageRNN(...), with the proper parameters. The criterion represents the function we will use to compute the loss of the model. The nn.CrossEntropyLoss() function basically applies a log softmax followed by a negative log-likelihood loss operation over the output of the model. To compute the loss, it needs the raw class scores produced by the model together with the target labels. We will see later in our code how these are passed to the criterion.
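As a small illustration of what the criterion expects (dummy_scores and dummy_targets are made-up values, used only to show the shapes):

dummy_scores = torch.randn(4, N_OUTPUTS)      # 4 instances, one raw score per class
dummy_targets = torch.tensor([0, 3, 1, 9])    # ground-truth class indices
print(criterion(dummy_scores, dummy_targets)) # a single scalar loss value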

For training, we also need an optimization algorithm which helps to update weights based on the current loss. This is achieved with the optim.Adam optimization function, which requires the model parameters and a learning rate. Alternatively, you can also use optim.SGD or any other optimization algorithm that's available.
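For example, swapping in SGD with momentum would look like the following sketch (the learning rate and momentum values here are just illustrative, and you would use it in place of the Adam optimizer defined above):

optimizer_sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)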

The get_accuracy(...) function simply computes the accuracy of the model given the output scores (logits) and the target values. As an exercise, you can write code to test this function as we did with the model before; a minimal example is shown below.
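For instance, reusing the logits and labels from the earlier test batch (the accuracy will be near chance, roughly 10%, since the model is still untrained):

print(get_accuracy(logits, labels, BATCH_SIZE))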

Let's put everything together and train our image classification model:

for epoch in range(N_EPOCHS):  # loop over the dataset multiple times
    train_running_loss = 0.0
    train_acc = 0.0
    model.train()
    
    # TRAINING ROUND
    for i, data in enumerate(trainloader):
         # zero the parameter gradients
        optimizer.zero_grad()
        
        # reset hidden states
        model.hidden = model.init_hidden() 
        
        # get the inputs
        inputs, labels = data
        inputs = inputs.view(-1, 28,28) 

        # forward + backward + optimize
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_running_loss += loss.detach().item()
        train_acc += get_accuracy(outputs, labels, BATCH_SIZE)
         
    model.eval()
    print('Epoch:  %d | Loss: %.4f | Train Accuracy: %.2f' 
          %(epoch, train_running_loss / i, train_acc/i))
Epoch:  0 | Loss: 0.7489 | Train Accuracy: 75.88
Epoch:  1 | Loss: 0.3113 | Train Accuracy: 91.05
Epoch:  2 | Loss: 0.2325 | Train Accuracy: 93.33
Epoch:  3 | Loss: 0.1957 | Train Accuracy: 94.53
Epoch:  4 | Loss: 0.1706 | Train Accuracy: 95.21
Epoch:  5 | Loss: 0.1564 | Train Accuracy: 95.58
Epoch:  6 | Loss: 0.1471 | Train Accuracy: 95.91
Epoch:  7 | Loss: 0.1329 | Train Accuracy: 96.14
Epoch:  8 | Loss: 0.1283 | Train Accuracy: 96.42
Epoch:  9 | Loss: 0.1196 | Train Accuracy: 96.65

We can also compute accuracy on the testing dataset to test how well the model performs on the image classification task. As you can see below, our RNN model is performing very well on the MNIST classification task.

test_acc = 0.0
for i, data in enumerate(testloader, 0):
    inputs, labels = data
    inputs = inputs.view(-1, 28, 28)

    outputs = model(inputs)

    test_acc += get_accuracy(outputs, labels, BATCH_SIZE)
        
print('Test Accuracy: %.2f'%( test_acc/i))
Test Accuracy: 95.83

Final Words

Please note that we are not using the GPU in this tutorial since the models we are building are relatively simple. As an exercise, you can take a look at the PyTorch documentation to learn how to place specific operations on the GPU and then try to adapt the code to run there (a rough sketch follows below). If you need help with this, reach out to me on Twitter.
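For reference, here is a minimal sketch of the changes a GPU run would involve, assuming the same ImageRNN model and data loaders as above (the loss, backward pass, and optimizer steps stay exactly the same):

model = model.to(device)  # move the model's parameters to the GPU, if one is available

for inputs, labels in trainloader:
    inputs = inputs.view(-1, 28, 28).to(device)  # move each minibatch to the same device
    labels = labels.to(device)
    outputs = model(inputs)
    # ... compute the loss, call backward(), and step the optimizer as before ...

# note: init_hidden() would also need to create its tensor on the same device,
# e.g. torch.zeros(1, self.batch_size, self.n_neurons).to(device)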

That's it for this tutorial. Congratulations! You are now able to implement a basic RNN in PyTorch. You also learned how to apply RNNs to solve a real-world, image classification problem.

In the next tutorial, we will do more advanced things with RNNs and try to solve even more complex problems, such as sarcasm detection and sentiment classification. Until next time!