Can ChatGPT differentiate and code it in Python? It seems so!
In the previous post we tested ChatGPT’s ability to turn LaTeX equations into Python code. This time we go a step further and ask ChatGPT to calculate a derivative, explain all the derivation steps, and implement it in Python.
What
Differentiation is an important operation in mathematics. Calculating a first derivative can involve several steps, and these steps are not always obvious, which can make the process challenging to carry out by hand.
The differentiation process can be seen as a text generation task using a Large Language Model (LLM). Instead of generating ordinary words, we generate mathematical expressions.
ChatGPT appears to have solid math skills for this kind of task. We therefore test its ability to calculate a derivative and explain every step. Finally, we ask ChatGPT to provide a Python implementation. To verify the correctness of the result, we compare it against automatic differentiation with the JAX library.
Why
The math skills of Large Language Models are one of their most valuable assets. Testing LLMs like ChatGPT on math puzzles can reveal their limitations. It also teaches us how to use ChatGPT effectively for our own math tasks, and even as a teaching resource.
How
Let’s consider the following function:
\[f(x) = \dfrac{x^2\cos(x) - x}{10}\]
We ask ChatGPT to provide a detailed explanation of how to calculate the first derivative. We use the following prompt:
Can you calculate and show all the steps of the derivative of the following
equation: 1/10 * (x**2 * cos(x) - x)
upon which we get a detailed response explaining the differentiation process. The conversation is shown below; the key steps are also summarized right after it:
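The derivation is short enough to reproduce here (a sketch of the standard product-rule calculation, not ChatGPT’s verbatim reply). Pulling the constant factor out and differentiating term by term gives
\[f'(x) = \dfrac{1}{10}\left(\dfrac{d}{dx}\left[x^2\cos(x)\right] - \dfrac{d}{dx}\left[x\right]\right)\]
The product rule gives \(\dfrac{d}{dx}\left[x^2\cos(x)\right] = 2x\cos(x) - x^2\sin(x)\), and the derivative of \(x\) is \(1\), so
\[f'(x) = \dfrac{2x\cos(x) - x^2\sin(x) - 1}{10}\]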
Using one of the online derivative calculators, such as the one from Symbolab, we can easily verify that the answer is correct. This is already impressive, but let’s go one step further and ask for a Python implementation using NumPy:
Let’s verify the result using automatic differentiation from JAX. First, we define the implementation provided by ChatGPT:
import numpy as np

# Function as implemented by ChatGPT.
def f_chatgpt(x):
    return 1/10 * (x**2 * np.cos(x) - x)

# First derivative as derived by ChatGPT.
def df_chatgpt(x):
    return 1/10 * (2*x*np.cos(x) - x**2*np.sin(x) - 1)
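As an extra sanity check of the hand-derived formula (illustrative, not part of the original notebook), we can compare it against a central finite difference at a single point:

# Illustrative check (not in the original notebook): the analytic derivative
# should agree with a central finite difference up to truncation error.
x0, h = 1.5, 1e-5
finite_diff = (f_chatgpt(x0 + h) - f_chatgpt(x0 - h)) / (2 * h)
print(df_chatgpt(x0), finite_diff)  # the two values agree to several decimal places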
Then we implement the same function using JAX. Here we only need the original function, since the derivative is computed by automatic differentiation.
import jax.numpy as jnp
from jax import grad, jit, vmap

# Same function written with jax.numpy so JAX can trace and differentiate it.
def f_jax(x):
    return 1/10 * (x**2 * jnp.cos(x) - x)

# JIT-compiled gradient of the scalar function, plus a vectorized
# (element-wise) version for evaluating the derivative over arrays.
jax_grad_f = jit(grad(f_jax))
jax_elementwise_grad = jit(vmap(jax_grad_f))
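Before evaluating on a whole grid, a quick single-point check (again illustrative, not from the original notebook) confirms that the JAX gradient matches the hand-derived formula. At x = 0 the derivative is -1/10:

# Illustrative single-point check (not in the original notebook):
# f'(0) = 1/10 * (2*0*cos(0) - 0**2*sin(0) - 1) = -0.1.
print(jax_grad_f(0.0))   # approximately -0.1
print(df_chatgpt(0.0))   # -0.1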
After this we generate equidistant points on the x-axis. Using both the ChatGPT and the JAX implementations, we calculate the derivative of the function at every point:
# Equidistant points on the x-axis.
x_start, x_stop = -10, 10
num_points = 1000
np_x = np.linspace(x_start, x_stop, num_points, dtype=np.float32)
jax_x = jnp.linspace(x_start, x_stop, num_points, dtype=jnp.float32)

# Derivative at every point: analytic formula vs. automatic differentiation.
fp_chatgpt = df_chatgpt(np_x)
fp_jax = jax_elementwise_grad(jax_x)
Finally, we compare the two results, only to conclude that they are close enough to be considered equal:
np.allclose(fp_chatgpt, fp_jax, atol=1.e-5) # prints True
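An equivalent check (illustrative, not part of the original notebook) is to look at the largest absolute deviation between the two derivative arrays directly:

# Illustrative: largest absolute deviation between the analytic and the
# JAX-computed derivatives (small, consistent with the allclose check above).
print(np.max(np.abs(fp_chatgpt - np.asarray(fp_jax))))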
The source code for this work can be found in this Jupyter Notebook. If you like this kind of content and would like to see more of it, you can follow me on LinkedIn or Twitter. Additionally, you can subscribe to the mailing list below to get similar updates from time to time.