# The beauty of the Walrus Operator in Python

The *assignment expression* operator, better known as the *Walrus* operator, was introduced in Python 3.8 with the PEP 572 proposal.
It assigns a value to a variable as part of an expression. After studying the characteristics of the *Walrus* operator, it’s hard not to love
it. Let’s take a look!

# What

The *Walrus* operator `:=` allows us to evaluate an expression and save its result in a variable, all in one line, using the notation `var := expr`.
Yes, it is named after the walrus, because the operator resembles the eyes and tusks of the animal.

The variable `var` can then be reused. In a way, we can see it as an operator for naming expressions.
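As a minimal sketch of the notation, the snippet below compares a plain assignment followed by a test with the same logic written as an assignment expression. The list `values` and the message strings are made up purely for illustration.

```python
values = [3, 1, 4, 1, 5, 9]

# Without the walrus operator: assign first, then test.
n = len(values)
if n > 5:
    message_plain = f"list is long ({n} elements)"

# With the walrus operator: assign and test in a single expression.
# The parentheses are needed because := binds less tightly than >.
if (count := len(values)) > 5:
    message_walrus = f"list is long ({count} elements)"

print(message_walrus)
```

Note that `count` stays available after the `if` statement, just like a regular variable.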

# Why

The *Walrus* operator makes *conditional* statements and *list comprehensions* much simpler. We can compute a value, test it,
and reuse the assigned variable later on, all in one line, which keeps the code concise and neat.
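To make the conditional case concrete, here is a small sketch that binds a regular expression match and tests it in the same `if` condition. The input string `line` and the pattern are hypothetical and chosen only for the example.

```python
import re

line = "error: disk quota exceeded"

# The match object is bound to `m` and checked in the same condition,
# so there is no separate assignment line before the `if`.
if (m := re.match(r"(\w+): (.*)", line)) is not None:
    level, text = m.group(1), m.group(2)
else:
    level, text = "unknown", line

print(level, "->", text)
```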

# How

There are multiple uses of the *assignment expression* operator; you can take a look at this thread on
StackOverflow. One of the most
useful applications is to share a subexpression between a comprehension filter clause and its output:

```
[y for x in data if (y := f(x)) is not None]
```
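The one-liner above can be tried with concrete inputs. In the sketch below, `f`, `data`, and `call_log` are hypothetical stand-ins: `f` returns a square root for non-negative numbers and `None` otherwise, and `call_log` only records how often `f` is invoked.

```python
call_log = []  # records every argument f is called with

def f(x):
    """Hypothetical transformation: square root, or None for negative input."""
    call_log.append(x)
    return x ** 0.5 if x >= 0 else None

data = [4, -1, 9, -16, 25]

# With the walrus operator, f(x) is evaluated once per element.
with_walrus = [y for x in data if (y := f(x)) is not None]
calls_with_walrus = len(call_log)

# Without it, f(x) is evaluated again for every element that passes the filter.
call_log.clear()
without_walrus = [f(x) for x in data if f(x) is not None]
calls_without_walrus = len(call_log)

print(with_walrus, calls_with_walrus, calls_without_walrus)
```

Both comprehensions produce the same list, but the version without the assignment expression calls `f` a second time for each element it keeps.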

One very interesting use case for the *Walrus* operator is calculating higher-order derivatives of a given function. For this purpose we
will use JAX, an updated version of Autograd for automatic differentiation.

With the *Walrus* operator we will calculate and plot the first 4 derivatives of the `tanh` function while keeping the expressions simple.

First we import all the required modules and generate evenly spaced numbers on the x-axis:

```
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import grad, vmap

x = jnp.linspace(-7, 7, 500)
```

Then, we can simply calculate and plot all the derivatives of the `tanh` function as demonstrated:

```
fig = plt.figure(figsize=(15, 9))
plt.plot(
    x, (dx := jnp.tanh)(x),     # original function
    x, vmap(dx := grad(dx))(x), # first derivative
    x, vmap(dx := grad(dx))(x), # second derivative
    x, vmap(dx := grad(dx))(x), # third derivative
    x, vmap(dx := grad(dx))(x), # fourth derivative
)
plt.show()
```
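This chaining works because Python evaluates the arguments of a call (and the elements of a list) from left to right, so each `dx := grad(dx)` sees the value bound by the previous one. The same pattern can be sketched without JAX; here `derivative` is a hypothetical helper that differentiates a polynomial stored as a list of coefficients, and it stands in for `grad` only for the sake of illustration.

```python
def derivative(coeffs):
    """Differentiate c0 + c1*x + c2*x**2 + ... given as [c0, c1, c2, ...]."""
    return [k * c for k, c in enumerate(coeffs)][1:]

# p(x) = 1 + 2x + 3x^2 + 4x^3
# Elements are evaluated left to right, so every step reuses the previous `p`.
results = [
    (p := [1, 2, 3, 4]),   # original polynomial
    (p := derivative(p)),  # first derivative:  2 + 6x + 12x^2
    (p := derivative(p)),  # second derivative: 6 + 24x
    (p := derivative(p)),  # third derivative:  24
]

for order, coeffs in enumerate(results):
    print(order, coeffs)
```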

To achieve the same effect without the *Walrus* operator, we have to nest the `grad` calls and rebuild every derivative from scratch. This is shown in the code snippet below:

```
fig = plt.figure(figsize=(15, 9))
plt.plot(
    x, jnp.tanh(x),                               # original function
    x, vmap(grad(jnp.tanh))(x),                   # first derivative
    x, vmap(grad(grad(jnp.tanh)))(x),             # second derivative
    x, vmap(grad(grad(grad(jnp.tanh))))(x),       # third derivative
    x, vmap(grad(grad(grad(grad(jnp.tanh)))))(x), # fourth derivative
)
plt.show()
```

The resulting plot is the one depicted below:

The source code for this work can be found in this **Jupyter Notebook**.