Can ChatGPT turn LaTeX equations into Python code?

ChatGPT seems to be good at turning equations into code. This raises the following question: can we use ChatGPT to turn complicated LaTeX equations into Python code?

What

Given an equation written in LaTeX, ChatGPT can convert it into Python code. We test this premise with a simple exercise.

Why

It is quite a powerful concept: we could turn the theory from publications into code and make it actionable.

How

To quickly test ChatGPT's ability to turn LaTeX equations into code, we set up a simple “experiment”.

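On the one hand, we take the scaled dot-product attention function behind the multi-headed attention in the well-known guide “The Annotated Transformer” (slightly simplified here: without the optional dropout, and returning only the output):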

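import math
import torch

def attention_reference(query, key, value, mask=None):
    # d_k is the size of the last (feature) dimension of the queries.
    d_k = query.size(-1)
    # Scaled query-key dot products: Q K^T / sqrt(d_k).
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Masked positions get a large negative score, i.e. ~0 attention weight.
        scores = scores.masked_fill(mask == 0, -1e9)
    # Row-wise softmax over the keys, then a weighted sum of the values.
    p_attn = scores.softmax(dim=-1)
    return torch.matmul(p_attn, value)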
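The reference function also accepts an optional mask, which the LaTeX equation does not express. A minimal sketch of what it does, assuming a causal (lower-triangular) mask built with torch.tril; the sizes, the seed, and the choice of mask are illustrative and not taken from the guide:

```python
import math

import torch


def attention_reference(query, key, value, mask=None):
    # Same function as above: scaled dot-product attention with an optional mask.
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Where mask == 0 the score becomes -1e9, so softmax assigns it ~0 weight.
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    return torch.matmul(p_attn, value)


torch.manual_seed(0)
seq_len, d_k = 5, 16
q = torch.randn(2, seq_len, d_k)
k = torch.randn(2, seq_len, d_k)
v = torch.randn(2, seq_len, d_k)

# Causal mask: query i may only attend to keys 0..i (broadcast over the batch).
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

out = attention_reference(q, k, v, mask=causal_mask)

# The first query can only see the first key, so its output equals v[:, 0].
print(torch.allclose(out[:, 0], v[:, 0]))
```

Using a large negative number instead of removing entries keeps every tensor the same shape, so the same matmul/softmax code path handles both the masked and the unmasked case.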

On the other hand, we prompt ChatGPT with the following sentence:

Can you write a Python code using PyTorch for the following equation 
written in LaTeX: a = softmax(\frac{QK^T}{\sqrt{d_k}})V

after which the following code snippet is generated:

import torch

def attention_chatgpt(Q, K, V, d_k):
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(d_k)
    scores = torch.softmax(scores, dim=-1)
    attention = torch.matmul(scores, V)
    return attention

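For convenience, the function was renamed to attention_chatgpt (originally it was named softmax_attention). Note that, unlike the reference, it takes d_k as an explicit argument and passes it to torch.sqrt, which expects a tensor rather than a plain Python number.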
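For reference, the equation from the prompt with its shapes made explicit. This is a sketch assuming n queries, m keys and a value width d_v; the symbols n, m, d_v, S and l are introduced here for illustration and are not part of the prompt:

```latex
\begin{aligned}
a &= \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V, \qquad
Q \in \mathbb{R}^{n \times d_k},\; K \in \mathbb{R}^{m \times d_k},\; V \in \mathbb{R}^{m \times d_v}, \\
\mathrm{softmax}(S)_{ij} &= \frac{\exp(S_{ij})}{\sum_{l=1}^{m} \exp(S_{il})}, \qquad a \in \mathbb{R}^{n \times d_v}.
\end{aligned}
```

The softmax is taken row-wise, over the m keys, which is why both implementations apply it with dim=-1; transpose(-2, -1) forms K^T on the last two dimensions only, so any leading batch and head dimensions are carried along unchanged.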

For testing purposes, we randomly generate the input tensors Q, K and V; d_k matches their last dimension (64):

Q = torch.randn(64, 8, 100, 64)
K = torch.randn(64, 8, 100, 64)
V = torch.randn(64, 8, 100, 64)
d_k = torch.tensor(64).float()
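The four dimensions of these tensors presumably follow The Annotated Transformer's layout of (batch, heads, sequence length, d_k). Because the computation only touches the last two dimensions, the same code also covers queries and keys of different lengths and values of a different width. A small sketch with illustrative sizes that are not part of the original experiment:

```python
import math

import torch


def attention_reference(query, key, value, mask=None):
    # Scaled dot-product attention, as in the reference implementation above.
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    return torch.matmul(p_attn, value)


torch.manual_seed(0)
batch, heads, n_queries, n_keys, d_k, d_v = 2, 4, 3, 7, 16, 5
Q = torch.randn(batch, heads, n_queries, d_k)
K = torch.randn(batch, heads, n_keys, d_k)
V = torch.randn(batch, heads, n_keys, d_v)

out = attention_reference(Q, K, V)
# One d_v-dimensional output per query: (batch, heads, n_queries, d_v).
print(out.shape)  # torch.Size([2, 4, 3, 5])
```

Only d_k has to agree between Q and K, and only the number of keys has to agree between K and V.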

Finally, we compare the outputs of both implementations with torch.allclose and conclude that they are close enough to be considered equal:

res_chatgpt = attention_chatgpt(Q, K, V, d_k)
res_reference = attention_reference(Q, K, V)
torch.allclose(res_chatgpt, res_reference)  # returns True
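To go one step beyond a single True/False, here is a sketch that also reports the largest element-wise difference and checks both results against PyTorch's built-in torch.nn.functional.scaled_dot_product_attention (available from PyTorch 2.0). The seed and the tolerance are assumptions of this sketch, not part of the original experiment:

```python
import math

import torch
import torch.nn.functional as F


def attention_reference(query, key, value, mask=None):
    # Reference implementation from The Annotated Transformer (simplified).
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    return torch.matmul(p_attn, value)


def attention_chatgpt(Q, K, V, d_k):
    # ChatGPT-generated implementation; d_k must be a tensor for torch.sqrt.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(d_k)
    scores = torch.softmax(scores, dim=-1)
    return torch.matmul(scores, V)


torch.manual_seed(0)  # fixed seed so runs are repeatable
Q = torch.randn(64, 8, 100, 64)
K = torch.randn(64, 8, 100, 64)
V = torch.randn(64, 8, 100, 64)
d_k = torch.tensor(64).float()

res_chatgpt = attention_chatgpt(Q, K, V, d_k)
res_reference = attention_reference(Q, K, V)
# Built-in kernel; its default scale is 1 / sqrt(Q.size(-1)).
res_builtin = F.scaled_dot_product_attention(Q, K, V)

# Largest element-wise gap between the two hand-written versions.
max_diff = (res_chatgpt - res_reference).abs().max().item()
print(f"max |chatgpt - reference| = {max_diff:.2e}")
print(torch.allclose(res_builtin, res_reference, atol=1e-5))
```

The built-in kernel may take a fused code path that rounds slightly differently, which is why this sketch compares it with an explicit atol rather than relying on the default tolerance.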


