Make every line of your Python code count with these 4 profilers
A chain is only as strong as its weakest link, and the same is true for a program: a single demanding line of code can make the whole thing inefficient. We have to find that line, using a set of code profiling tools.
What
In Python, there is a well-developed ecosystem of profilers. These profilers help us check what makes a program slow and give us hints on how to improve it. We will see how to use the following four profilers to inspect every line of a Python program:
- snakeviz: a browser-based graphical viewer of profiling results
- line_profiler: profiles the code line by line
- memory_profiler: profiles the memory consumption of every line
- pyinstrument: a call-stack profiler focusing on the slowest calls
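Before reaching for any of these tools, it is worth remembering that Python's standard library already ships cProfile, which gives a quick function-level baseline; the four profilers above then refine that picture down to individual lines. A minimal sketch (the profiled function here is just a deliberately slow stand-in, not from the article):

```python
import cProfile
import io
import pstats

def slow_sum(n):
    # Deliberately naive: building a throwaway list on every iteration
    # dominates the runtime.
    total = 0
    for i in range(n):
        total += sum([i] * 10)
    return total

profiler = cProfile.Profile()
profiler.enable()
slow_sum(10_000)
profiler.disable()

# Render the ten most time-consuming calls, sorted by cumulative time.
stream = io.StringIO()
pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(10)
report = stream.getvalue()
print(report[:120])
```

This shows which functions the time went to, but not which lines inside them; that is exactly the gap line_profiler and friends fill.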
Why
Code inefficiency can be buried deep in the layers of programming abstraction, which makes it difficult to spot.
Nevertheless, code efficiency is always a desirable property. Sometimes it is of paramount importance, and we need to strive for it.
How
Let's say we want to profile the scaled dot-product attention implemented as:
import math

import torch


def attention(query, key, value, mask=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    return torch.matmul(p_attn, value), p_attn
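For readers without torch at hand, the same computation can be sketched in NumPy (a hypothetical reference implementation, assuming NumPy is available; it is useful for checking shapes and the softmax normalization, but it is not the code profiled below):

```python
import math

import numpy as np

def attention_np(query, key, value, mask=None):
    # Scaled dot-product attention expressed with NumPy primitives.
    d_k = query.shape[-1]
    scores = query @ np.swapaxes(key, -2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)
    # Numerically stable softmax over the last axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    p_attn = np.exp(scores)
    p_attn /= p_attn.sum(axis=-1, keepdims=True)
    return p_attn @ value, p_attn
```

Each row of the returned attention matrix sums to 1, and the first output has the shape (batch, sequence length, d_v).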
To profile it, we randomly generate the input tensors query, key and value:
d_k = 64 # Dimensionality of the linearly projected queries and keys
d_v = 64 # Dimensionality of the linearly projected values
batch_size = 64 # Batch size from the training process
input_seq_length = 100 # Maximum length of the input sequence
query = torch.rand((batch_size, input_seq_length, d_k))
key = torch.rand((batch_size, input_seq_length, d_k))
value = torch.rand((batch_size, input_seq_length, d_v))
mask = torch.FloatTensor(input_seq_length, input_seq_length).uniform_() > 0.8
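Note the mask construction: `uniform_() > 0.8` keeps a position with probability ~0.2, so roughly 80% of the score entries get filled with -1e9. A quick pure-Python sanity check of that density (a stand-in sketch; the article itself builds torch tensors):

```python
import random

# Same dimensions as the torch tensors above.
d_k, d_v, batch_size, input_seq_length = 64, 64, 64, 100

random.seed(0)
# Each entry is True (kept) with probability ~0.2.
mask = [
    [random.random() > 0.8 for _ in range(input_seq_length)]
    for _ in range(input_seq_length)
]
kept = sum(sum(row) for row in mask) / input_seq_length**2
print(f"fraction of unmasked positions: {kept:.2f}")
```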
Then, in a Jupyter Notebook, we can profile every aspect of the function with this snippet (note that %%pyinstrument is a cell magic, so it must start its own cell):
%load_ext snakeviz
%load_ext line_profiler
%load_ext memory_profiler
%load_ext pyinstrument
%snakeviz -t attention(query, key, value, mask)
%lprun -T lprof_attention.txt -f attention attention(query, key, value, mask)
%mprun -T mprof_attention.txt -f attention attention(query, key, value, mask)
%%pyinstrument --timeline=True
_ = attention(query, key, value, mask)
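Outside Jupyter, line_profiler is usually driven via kernprof and memory_profiler via its @profile decorator. For per-line memory attribution without any third-party dependency, the standard library's tracemalloc offers a rough alternative: it groups live allocations by source line. A minimal sketch with a hypothetical stand-in function:

```python
import tracemalloc

def build_tables():
    # Two allocations of noticeably different sizes.
    small = [0] * 1_000
    large = [0] * 100_000
    return small, large

tracemalloc.start()
tables = build_tables()
snapshot = tracemalloc.take_snapshot()
tracemalloc.stop()

# Group traced allocations by source line, largest first.
stats = snapshot.statistics("lineno")
for stat in stats[:2]:
    print(stat)
```

The top entry should point at the `large = ...` line, mirroring (coarsely) what memory_profiler's per-line "Increment" column reports.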
The snakeviz profiler will give us a beautiful visualization of the call stack, as demonstrated in the picture below:
The line_profiler and memory_profiler will show us the details for every line of code, producing the detailed textual reports shown below:
Timer unit: 1e-09 s

Total time: 0.014188 s
Line # Hits Time Per Hit % Time Line Contents
============================================================================================================================
6 def attention(query, key, value, mask=None):
7 1 18000.0 18000.0 0.1 d_k = query.size(-1)
8 1 6250000.0 6250000.0 44.1 scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
9 1 1000.0 1000.0 0.0 if mask is not None:
10 1 1070000.0 1070000.0 7.5 scores = scores.masked_fill(mask == 0, -1e9)
11 1 4547000.0 4547000.0 32.0 p_attn = scores.softmax(dim=-1)
12 1 2302000.0 2302000.0 16.2 return torch.matmul(p_attn, value), p_attn
Line # Mem usage Increment Occurrences Line Contents
============================================================================================================================
6 189.9 MiB 189.9 MiB 1 def attention(query, key, value, mask=None):
7 189.9 MiB 0.0 MiB 1 d_k = query.size(-1)
8 189.9 MiB 0.0 MiB 1 scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
9 189.9 MiB 0.0 MiB 1 if mask is not None:
10 190.0 MiB 0.1 MiB 1 scores = scores.masked_fill(mask == 0, -1e9)
11 190.0 MiB 0.0 MiB 1 p_attn = scores.softmax(dim=-1)
12 190.0 MiB 0.0 MiB 1 return torch.matmul(p_attn, value), p_attn
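The line report attributes ~44% of the time to the query-key matmul and ~32% to the softmax. A back-of-the-envelope FLOP count for the shapes generated above shows why the matmuls are heavy: both are quadratic in the sequence length. (Timings also include memory traffic and allocation, so FLOPs alone do not predict the exact split.)

```python
# Shapes from the example above.
batch_size, seq_len, d_k, d_v = 64, 100, 64, 64

# scores = query @ key^T: a (seq, d_k) x (d_k, seq) matmul per batch
# element, with 2 FLOPs (multiply + add) per inner-product term.
qk_flops = 2 * batch_size * seq_len * seq_len * d_k

# output = p_attn @ value: a (seq, seq) x (seq, d_v) matmul per batch element.
av_flops = 2 * batch_size * seq_len * seq_len * d_v

print(f"QK^T matmul:   {qk_flops:,} FLOPs")
print(f"attn @ V matmul: {av_flops:,} FLOPs")
```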
Finally, pyinstrument will render the following trace:
The source code for this work can be found in this Jupyter Notebook. If you like this content, you can follow me on LinkedIn or Twitter, or subscribe to the mailing list below to get similar updates from time to time.