The beauty of the Walrus Operator in Python


Assignment expressions, written with the so-called Walrus operator, were introduced in Python 3.8 through the PEP 572 proposal. They let us assign to a variable within an expression. After studying the characteristics of the Walrus operator, it's hard not to love it. Let's take a look!

What

The Walrus operator := allows us to evaluate an expression and save its result in a variable, all in one line, using the notation var := expr. And yes, it is named after the animal, because the operator resembles the tusks of a walrus.

The variable var can then be reused. In a way, we can see it as an operator for naming expressions.
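As a minimal sketch of the var := expr notation (the names values, total and doubled_total are made up for illustration), the snippet below names a subexpression in the middle of a larger expression and reuses it afterwards:

```python
values = [3, 1, 4, 1, 5]

# sum(values) is evaluated once, bound to `total`, and the bound value
# is immediately used as an operand of the multiplication.
doubled_total = 2 * (total := sum(values))

# `total` remains available after the expression has been evaluated.
print(total, doubled_total)
```

The parentheses around the assignment expression are needed here: without them, := would bind the whole right-hand side rather than just sum(values).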

Why

The Walrus operator makes conditional statements and list comprehensions much simpler. We can compute a value, test it, and reuse the assigned variable later on, all without a separate assignment line, which keeps the code concise and neat.
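For instance, a conditional that both computes a value and tests it could look like the sketch below. The function parse_version and its regular expression are hypothetical and serve only to illustrate the pattern:

```python
import re


def parse_version(text):
    """Return (major, minor) found in `text`, or None if there is no match."""
    # The match object is assigned and tested in a single condition,
    # so no separate `match = ...` line is needed before the `if`.
    if (match := re.search(r"(\d+)\.(\d+)", text)) is not None:
        return int(match.group(1)), int(match.group(2))
    return None


print(parse_version("The operator arrived in Python 3.8"))
print(parse_version("no version here"))
```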

How

There are multiple uses of the assignment expression operator. You can take a look at this thread on StackOverflow. One of the most useful applications is to share a subexpression between a comprehension filter clause and its output:

[y for x in data if (y := f(x)) is not None]
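To make the one-liner above concrete, here is a runnable sketch. The function f, the list data and the call counter are assumptions added for illustration; the point is that f runs once per element with the Walrus operator, while the version without it calls f a second time for every element that passes the filter:

```python
calls = 0


def f(x):
    """Hypothetical transformation: square even numbers, reject odd ones."""
    global calls
    calls += 1
    return x * x if x % 2 == 0 else None


data = [1, 2, 3, 4]

# With the Walrus operator: f(x) is evaluated once and shared
# between the filter clause and the output expression.
with_walrus = [y for x in data if (y := f(x)) is not None]
calls_with_walrus = calls

# Without it: f(x) is evaluated in the filter and again in the output.
calls = 0
without_walrus = [f(x) for x in data if f(x) is not None]
calls_without_walrus = calls

print(with_walrus, calls_with_walrus)
print(without_walrus, calls_without_walrus)
```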

One very interesting use case for the Walrus operator is calculating higher-order derivatives of a given function. For this purpose we will use JAX, an updated version of Autograd for automatic differentiation.

With the Walrus operator we will calculate and plot the first 4 derivatives of the tanh function while keeping the expressions simple.

First we import all the required modules and generate evenly spaced numbers on the x-axis:

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import grad, vmap

x = jnp.linspace(-7, 7, 500)

Then, we can simply calculate and plot all the derivatives of the tanh function as demonstrated:

fig = plt.figure(figsize=(15, 9))

plt.plot(
    x, (dx := jnp.tanh)(x),      # original function
    x, vmap(dx := grad(dx))(x),  # first derivative
    x, vmap(dx := grad(dx))(x),  # second derivative
    x, vmap(dx := grad(dx))(x),  # third derivative
    x, vmap(dx := grad(dx))(x),  # fourth derivative
)

plt.show()
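The chained reassignment works because Python evaluates the arguments of a call (and the items of a list) from left to right, so each dx := grad(dx) sees the function bound by the previous one. The same pattern can be sketched without JAX. Here poly_deriv is a hypothetical helper, not part of JAX, that differentiates a polynomial stored as a list of coefficients in increasing order of power:

```python
def poly_deriv(coeffs):
    """Differentiate a polynomial given as [c0, c1, c2, ...] for c0 + c1*x + c2*x^2 + ..."""
    # The term c_i * x^i becomes i * c_i * x^(i-1); the constant term drops out.
    return [i * c for i, c in enumerate(coeffs)][1:]


# Items are evaluated left to right, so every entry differentiates
# the polynomial bound to `p` by the entry before it.
derivatives = [
    (p := [0, 0, 0, 1]),   # x^3
    (p := poly_deriv(p)),  # first derivative:  3x^2
    (p := poly_deriv(p)),  # second derivative: 6x
    (p := poly_deriv(p)),  # third derivative:  6
]

print(derivatives)
```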

Without the Walrus operator, producing the same plot requires nesting the grad calls and rebuilding every derivative from scratch, as the code snippet below shows:

fig = plt.figure(figsize=(15, 9))

plt.plot(
    x, jnp.tanh(x),                                # original function
    x, vmap(grad(jnp.tanh))(x),                    # first derivative
    x, vmap(grad(grad(jnp.tanh)))(x),              # second derivative
    x, vmap(grad(grad(grad(jnp.tanh))))(x),        # third derivative
    x, vmap(grad(grad(grad(grad(jnp.tanh)))))(x),  # fourth derivative
)

plt.show()

The resulting plot is the one depicted below:

Figure 1: tanh (in blue) and its first 4 derivatives


The source code for this work can be found in this Jupyter Notebook. If you like this and would like to see similar content, you can follow me on LinkedIn or Twitter. You can also subscribe to the mailing list to get similar updates from time to time.

