# Can ChatGPT differentiate and code it in Python? It seems so!

In the previous post we tested ChatGPT’s ability to turn LaTeX equations into Python code. This time we go one step further and ask ChatGPT to calculate a derivative, explain all derivation steps and implement the result in Python.

# What

Differentiation is an important operation in mathematics. Calculating a first-order derivative might involve several operations. These operations are not necessarily obvious, so completing the process can be challenging.

The differentiation process can be seen as a text generation task using a Large Language Model (LLM). Instead of generating ordinary words, we generate mathematical expressions.

*ChatGPT* seems to have excellent math skills. We therefore test its ability to calculate a derivative and explain all the steps. Finally, we ask *ChatGPT* to provide a Python implementation. To verify correctness, we compare the results against the JAX library for automatic differentiation.

# Why

Math skills are one of the crucial assets of Large Language Models. Testing LLMs like *ChatGPT* on math puzzles can reveal their limitations. We can also learn how to use *ChatGPT* properly for our own math tasks, and even use it as a teaching resource.

# How

Let’s consider the following function:

\[f(x) = \dfrac{x^2\cos(x) - x}{10}\]

We ask *ChatGPT* to provide a detailed explanation of how to calculate the first-order derivative. We use the following prompt:

```
Can you calculate and show all the steps of the derivative of the following
equation: 1/10 * (x**2 * cos(x) - x)
```

In response, we get a detailed walkthrough of the differentiation process. The conversation is shown below:
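For reference, the derivation can also be sketched by hand. The steps below are a reconstruction from the constant-multiple, difference and product rules, not ChatGPT's verbatim output; they agree with the `df_chatgpt` code used later in this post.

\[f'(x) = \dfrac{1}{10}\,\dfrac{d}{dx}\left(x^2\cos(x) - x\right)\]

\[\dfrac{d}{dx}\left(x^2\cos(x)\right) = 2x\cos(x) - x^2\sin(x), \qquad \dfrac{d}{dx}\,x = 1\]

\[f'(x) = \dfrac{2x\cos(x) - x^2\sin(x) - 1}{10}\]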

Using an online derivative calculator such as the one from Symbolab, we can easily verify that the answer is the same. This is already amazing, but let’s go one step further and ask for a Python implementation using NumPy:

Let’s verify the result using automatic differentiation from JAX. First we define the implementation provided by *ChatGPT*:

```python
import numpy as np


def f_chatgpt(x):
    # Original function: f(x) = (x^2 * cos(x) - x) / 10
    return 1/10 * (x**2 * np.cos(x) - x)


def df_chatgpt(x):
    # Derivative as given by ChatGPT
    return 1/10 * (2*x*np.cos(x) - x**2*np.sin(x) - 1)
```
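Before bringing in JAX, a quick independent sanity check is possible with a central finite difference. The snippet below is only a sketch: `central_difference` is a hypothetical helper, and the step size `h`, the number of points and the tolerance are illustrative choices that are not part of the original notebook.

```python
import numpy as np


# Same definitions as in the post, repeated so the snippet runs on its own.
def f_chatgpt(x):
    return 1/10 * (x**2 * np.cos(x) - x)


def df_chatgpt(x):
    return 1/10 * (2*x*np.cos(x) - x**2*np.sin(x) - 1)


def central_difference(f, x, h=1e-5):
    """Approximate f'(x) with the symmetric difference quotient (hypothetical helper)."""
    return (f(x + h) - f(x - h)) / (2 * h)


# float64 points on the same interval as in the post; h and the tolerance are assumptions.
x = np.linspace(-10, 10, 201)
max_abs_error = np.max(np.abs(central_difference(f_chatgpt, x) - df_chatgpt(x)))
fd_matches = bool(max_abs_error < 1e-6)
```

Unlike automatic differentiation, this approach only approximates the derivative, so the achievable agreement depends on the choice of `h`.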

Then we implement the same function using *JAX*. Here we only need the original function, because JAX obtains the derivative through automatic differentiation.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap


def f_jax(x):
    return 1/10 * (x**2 * jnp.cos(x) - x)


# grad differentiates the scalar function; vmap applies it to each array element
jax_grad_f = jit(grad(f_jax))
jax_elementwise_grad = jit(vmap(jax_grad_f))
```
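To illustrate why both `grad` and `vmap` are involved, here is a minimal sketch. `grad` expects a function with a scalar output, so on its own it handles one point at a time; wrapping it in `vmap` maps it over the leading axis of an array. The names `grad_f`, `batched_grad_f`, `scalar_slope`, `xs` and `slopes` are illustrative and do not appear in the original notebook.

```python
import jax
import jax.numpy as jnp


def f_jax(x):
    return 1/10 * (x**2 * jnp.cos(x) - x)


grad_f = jax.grad(f_jax)           # derivative of a scalar-in, scalar-out function
batched_grad_f = jax.vmap(grad_f)  # the same derivative, applied to each array element

# Single point: analytically f'(0) = (0 - 0 - 1) / 10 = -1/10.
scalar_slope = grad_f(0.0)

# Several points at once: one derivative value per input element.
xs = jnp.array([0.0, 1.0, 2.0])
slopes = batched_grad_f(xs)
```

Calling `grad_f` directly on `xs` would fail, since the function would then return an array instead of a scalar; this is the reason for the `vmap` wrapper.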

After this we generate equidistant points on the *x-axis*. Using the *ChatGPT* and *JAX* implementations, we calculate the derivative of the function at every point:

```python
x_start, x_stop = -10, 10
num_points = 1000

# The same grid, once as a NumPy array and once as a JAX array
np_x = np.linspace(x_start, x_stop, num_points, dtype=np.float32)
jax_x = jnp.linspace(x_start, x_stop, num_points, dtype=jnp.float32)

fp_chatgpt = df_chatgpt(np_x)
fp_jax = jax_elementwise_grad(jax_x)
```

Finally, we compare the results and conclude that they are close enough to be considered equal:

```python
np.allclose(fp_chatgpt, fp_jax, atol=1.e-5)  # returns True
```
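A short note on the tolerance: `np.allclose(a, b)` checks `|a - b| <= atol + rtol * |b|` element-wise, with defaults `rtol=1e-5` and `atol=1e-8`. The explicit `atol` can matter wherever the derivative is close to zero, because the relative term becomes negligible there. The sketch below uses made-up values purely to show that effect; they are not results from the notebook.

```python
import numpy as np

# Made-up values near zero, for illustration only.
a = np.array([0.0])
b = np.array([5e-6])

# Default tolerances: 1e-8 + 1e-5 * 5e-6 is far smaller than the 5e-6 gap.
default_result = np.allclose(a, b)

# With atol=1e-5 the absolute tolerance alone covers the 5e-6 gap.
relaxed_result = np.allclose(a, b, atol=1e-5)
```

Here `default_result` is `False` and `relaxed_result` is `True`.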

The source code for this work can be found in this Jupyter Notebook.
