Batch vs Layer Normalization in Deep Neural Nets. The Illustrated Way!

3 minute read

Batch Normalization (BN) and Layer Normalization (LN) are widely used techniques in deep learning. They ease the optimization process and help very deep networks converge faster.

Batch Normalization (BN) has been applied successfully to vision tasks, while Layer Normalization (LN) is preferred for sequential tasks, mainly in NLP.

Both are normalization techniques applied to the input of each layer. Both compute the same two statistics, mean and variance, only over different dimensions.

Fully understanding the difference between BN and LN is not entirely straightforward. For this reason, in this blog we explain batch and layer normalization with intuitive illustrations.

Batch Normalization

Batch Normalization (BN) was first introduced to solve the internal covariate shift, i.e. the change in the distributions of the hidden layers' activations in the course of training.

In general, BN accelerates the training of deep neural nets. It also reduces the dependence of gradients on the scale of the parameters (or of their initial values), which in turn allows the use of much higher learning rates. However, it has one drawback: it requires a sufficiently large batch size.

To save us the pain of reading the entire paper, and without going too much into the details, the essential part of how Batch Normalization works is illustrated in the image below:

Illustrated Batch Normalization


In Batch Normalization the mean and variance are calculated for each individual channel, across all elements (pixels or tokens) of all samples in the batch.

Although at first sight it may sound counterintuitive, it is called Batch Normalization because it aggregates over the whole batch, as the sketch below shows.
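
To make this concrete, here is a minimal sketch (not from the original post) of the statistics BN computes, assuming an image-like batch of shape (batch, channels, height, width):

import torch

x = torch.randn(8, 3, 32, 32)  # a batch of 8 images with 3 channels of 32x32 pixels

# one mean and one variance per channel, computed over the batch and spatial dimensions
mean = x.mean(dim=(0, 2, 3))
var = x.var(dim=(0, 2, 3), unbiased=False)

print(mean.shape, var.shape)  # torch.Size([3]) torch.Size([3])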

Layer Normalization

Having a sufficiently large batch size is impractical for sequential tasks, where sequences can be very long. To mitigate this constraint, the Layer Normalization (LN) technique was introduced.

Thus, LN does not depend on the batch size and can be used even with very small batches. It can also help reduce the vanishing gradient problem in recurrent neural networks.

Again, to save us the time of reading the entire paper, the essential part of how Layer Normalization works is illustrated in the image below:

Illustrated Layer Normalization


In Layer Normalization the mean and variance are calculated for each individual element (pixel or token) separately, across all channels.

Again, this may be counterintuitive at first sight, but it is called Layer Normalization because it aggregates over all channels, i.e. the features of the layer, as the sketch below shows.
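
For comparison, here is the analogous sketch (again, not from the original post) for LN, assuming a sequence batch of shape (batch, tokens, features):

import torch

x = torch.randn(8, 16, 64)  # a batch of 8 sequences, each with 16 tokens of 64 features

# one mean and one variance per token, computed over the feature dimension
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)

print(mean.shape, var.shape)  # torch.Size([8, 16, 1]) torch.Size([8, 16, 1])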

PyTorch Implementation

The PyTorch implementation is given in the code snippets below. In both modules we create two learnable parameters, gamma and beta, which scale and shift the normalized input.

For Batch Normalization, during training we also keep track of a moving mean and a moving variance. During inference we use these moving averages in place of the batch statistics, since a single test sample does not provide reliable statistics. Layer Normalization computes its statistics per element, so it needs no moving averages and behaves the same during training and inference.

import torch
import torch.nn as nn

Below you can find the Batch Normalization implementation in PyTorch:

class BatchNorm(nn.Module):
    def __init__(self, num_features: int, eps: float = 1e-6, momentum: float = 0.1) -> None:
        super().__init__()

        # learnable parameters: gamma scales and beta shifts the normalized input
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

        # hyperparams
        self.eps = eps
        self.momentum = momentum

        # running statistics are not learnable, so they are registered as buffers
        self.register_buffer("moving_mean", torch.zeros(num_features))
        self.register_buffer("moving_var", torch.ones(num_features))

    def forward(self, x):
        # assumes input of shape (batch, num_features);
        # self.training is the flag nn.Module toggles via .train() / .eval()
        if self.training:
            # one mean and one variance per feature, computed over the batch dimension
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)

            # update the running statistics without tracking gradients
            with torch.no_grad():
                self.moving_mean = (1 - self.momentum) * self.moving_mean + self.momentum * mean
                self.moving_var = (1 - self.momentum) * self.moving_var + self.momentum * var
        else:
            mean = self.moving_mean
            var = self.moving_var

        x = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x + self.beta
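
As a quick sanity check (a sketch, not part of the original post), in training mode the module above should closely match PyTorch's built-in nn.BatchNorm1d for 2D inputs:

x = torch.randn(32, 10)

custom_bn = BatchNorm(num_features=10)
torch_bn = nn.BatchNorm1d(10, eps=1e-6)

# both modules are in training mode by default, so both normalize with batch statistics
print(torch.allclose(custom_bn(x), torch_bn(x), atol=1e-5))  # True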

Below you can find the Layer Normalization implementation in PyTorch:

class LayerNorm(nn.Module):
    def __init__(self, num_features: int, eps: float = 1e-6) -> None:
        super().__init__()

        # learnable parameters: gamma scales and beta shifts the normalized input
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

        # hyperparams
        self.eps = eps

    def forward(self, x):
        # one mean and one variance per element, computed over the feature dimension;
        # no running averages are needed, so training and inference behave the same
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)

        x = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x + self.beta
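
Similarly, a quick sanity check (again a sketch, not part of the original post) against PyTorch's built-in nn.LayerNorm over the last dimension:

x = torch.randn(8, 16, 64)

custom_ln = LayerNorm(num_features=64)
torch_ln = nn.LayerNorm(64, eps=1e-6)

print(torch.allclose(custom_ln(x), torch_ln(x), atol=1e-5))  # True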

Take a look at and download the PDF document containing the illustrations above by clicking on the button below:

Download Illustrations

For more information, please follow me on LinkedIn or Twitter. If you like this content you can subscribe to the mailing list below to get similar updates from time to time.

