Make every line of your Python code count with these 4 profilers

3 minute read

A chain is only as strong as its weakest link. If we draw an analogy between chain links and lines of code, the same is true for a program: a single demanding line of code can make the entire program inefficient. Code profiling tools help us find out which line it is.

What

Python has a well-developed ecosystem of profilers. Profilers help us check what makes a program slow and give us hints on how to improve it. We will see how to use the following four profilers to inspect every line of a Python program (an installation sketch follows the list):

  • snakeviz: browser-based graphical viewer of the profiling results
  • line_profiler: profiles the code line by line
  • memory_profiler: profiles the memory consumption of every line
  • pyinstrument: call-stack profiler focusing on the slowest calls
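
All four profilers are third-party packages. Assuming a standard pip-based setup, they can typically be installed with:

pip install snakeviz line_profiler memory_profiler pyinstrument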

Why

Code inefficiency can be buried deep in the layers of programming abstractions, which can make it difficult to spot.

Nevertheless, code efficiency is always a desirable property. Sometimes it is of paramount importance, and we need to strive for it.

How

Let’s say we want to profile the scaled dot product attention implemented as:

import math
import torch

def attention(query, key, value, mask=None):
    d_k = query.size(-1)
    # Scale the dot products of queries and keys by sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Suppress the masked-out positions before the softmax
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    return torch.matmul(p_attn, value), p_attn
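
This is the standard scaled dot-product attention from the Transformer architecture; in formula form, the function computes

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V

where d_k is the dimensionality of the queries and keys.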

To profile it, we randomly generate the input tensors query, key, and value, together with a boolean mask:

d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
batch_size = 64  # Batch size from the training process
input_seq_length = 100  # Maximum length of the input sequence

query = torch.rand((batch_size, input_seq_length, d_k))
key = torch.rand((batch_size, input_seq_length, d_k))
value = torch.rand((batch_size, input_seq_length, d_v))
# Boolean mask; entries are True with probability ~0.2, and positions
# where mask == 0 are suppressed inside attention()
mask = torch.FloatTensor(input_seq_length, input_seq_length).uniform_() > 0.8
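
As a quick sanity check (a minimal sketch, not part of the profiling itself), we can call the function once and inspect the output shapes:

output, p_attn = attention(query, key, value, mask)
print(output.shape)  # torch.Size([64, 100, 64]): (batch, seq_length, d_v)
print(p_attn.shape)  # torch.Size([64, 100, 100]): (batch, seq_length, seq_length)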

Then, in a Jupyter Notebook, we can profile every aspect of the function with the snippet below. Note that %mprun can only profile functions defined in a file and imported into the notebook, so attention has to live in a module for the memory report to work:

%load_ext snakeviz
%load_ext line_profiler
%load_ext memory_profiler
%load_ext pyinstrument

%snakeviz -t attention(query, key, value, mask)
%lprun -T lprof_attention.txt -f attention attention(query, key, value, mask)
%mprun -T mprof_attention.txt -f attention attention(query, key, value, mask)

# In a separate cell, since %%pyinstrument is a cell magic:
%%pyinstrument --timeline=True
_ = attention(query, key, value, mask)
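
Outside a notebook, the same profilers can be driven from the command line. As a rough sketch, assuming the call to attention lives in a script named attention_demo.py (a hypothetical name) and the function is decorated with @profile for the two line-level tools:

python -m cProfile -o attention.prof attention_demo.py  # record a cProfile trace
snakeviz attention.prof                                 # browse it in snakeviz
kernprof -l -v attention_demo.py                        # line_profiler report
python -m memory_profiler attention_demo.py             # memory_profiler report
pyinstrument attention_demo.py                          # pyinstrument report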

The snakeviz profiler gives us a beautiful visualization of the call stack, as demonstrated in the figure below:

Figure 1: Icicle diagram showing the call stack


The line_profiler and memory_profiler extensions show the details for every line of code. They output the detailed textual reports shown below:

Timer unit: 1e-09 s

Total time: 0.014188 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
============================================================================================================================
     6                                           def attention(query, key, value, mask=None):
     7         1      18000.0  18000.0      0.1      d_k = query.size(-1)
     8         1    6250000.0 6250000.0     44.1      scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
     9         1       1000.0   1000.0      0.0      if mask is not None:
    10         1    1070000.0 1070000.0      7.5          scores = scores.masked_fill(mask == 0, -1e9)
    11         1    4547000.0 4547000.0     32.0      p_attn = scores.softmax(dim=-1)
    12         1    2302000.0 2302000.0     16.2      return torch.matmul(p_attn, value), p_attn

Line #    Mem usage    Increment  Occurrences   Line Contents
============================================================================================================================
     6    189.9 MiB    189.9 MiB           1   def attention(query, key, value, mask=None):
     7    189.9 MiB      0.0 MiB           1       d_k = query.size(-1)
     8    189.9 MiB      0.0 MiB           1       scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
     9    189.9 MiB      0.0 MiB           1       if mask is not None:
    10    190.0 MiB      0.1 MiB           1           scores = scores.masked_fill(mask == 0, -1e9)
    11    190.0 MiB      0.0 MiB           1       p_attn = scores.softmax(dim=-1)
    12    190.0 MiB      0.0 MiB           1       return torch.matmul(p_attn, value), p_attn
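
The reports make the hot spots obvious: the two matrix multiplications and the softmax account for over 90% of the runtime. As one possible follow-up (a sketch assuming PyTorch 2.0 or newer, not something the original notebook does), the function body could be swapped for PyTorch's fused kernel, at the cost of no longer getting the attention weights back:

import torch.nn.functional as F

def attention_fused(query, key, value, mask=None):
    # Fused scaled dot-product attention (PyTorch >= 2.0).
    # With a boolean attn_mask, positions where the mask is False are ignored,
    # matching the masked_fill(mask == 0, -1e9) behaviour above.
    # Note: the fused kernel does not return the attention weights.
    return F.scaled_dot_product_attention(query, key, value, attn_mask=mask)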

Finally, pyinstrument will render the following trace:

Figure 2: PyInstrument Trace


The source code for this work can be found in this Jupyter Notebook. If you like this post and would like to see similar content, you can follow me on LinkedIn or Twitter. Additionally, you can subscribe to the mailing list below to get similar updates from time to time.

