Understand high-dimensional data with three lines of code
It is difficult to gain insights from multivariate data, and it is also challenging to plot such data in a way that aids understanding. Nevertheless, there are techniques for visualizing high-dimensional data, and we can apply them easily via the Pandas plotting API.
What
The Pandas plotting API contains many visualization techniques that help us understand the data stored in a DataFrame. For multivariate data, three plots are particularly handy: Andrews curves, parallel coordinates, and the scatter matrix.
Why
Understanding the correlations between the features of multidimensional data points is important. An equally important task is detecting outliers. Together, these give us hints about how the feature space is divided among the classes and whether the features are discriminative enough.
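The plots let us judge correlations and outliers by eye. As a numeric cross-check, here is a minimal sketch using pandas' `DataFrame.corr` and a simple z-score rule. The toy data, the column names `a`, `b` and `c`, and the threshold of 3 standard deviations are illustrative assumptions, not part of the original workflow.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: "b" depends linearly on "a", "c" is independent noise
rng = np.random.default_rng(0)
a = rng.normal(size=200)
df = pd.DataFrame({
    "a": a,
    "b": 2 * a + rng.normal(scale=0.1, size=200),
    "c": rng.normal(size=200),
})
df.loc[0, "c"] = 10.0  # inject one obvious outlier into row 0

# Pairwise Pearson correlations between the features
corr = df.corr()

# Flag rows where any feature lies more than 3 standard deviations from its mean
# (the threshold of 3 is an assumed rule of thumb)
z = (df - df.mean()) / df.std()
outlier_rows = df[(z.abs() > 3).any(axis=1)].index

print(corr.round(2))
print(list(outlier_rows))
```

A strongly correlated pair such as `a` and `b` would appear as a tight diagonal cloud in a scatter matrix, and the injected value in `c` as an isolated point far from the rest.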
How
It is very easy to produce the plots using the Pandas plotting API. We will use the Iris dataset to demonstrate this.
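import matplotlib.pyplot as plt
import pandas as pd

# Load the Iris dataset: four numeric features plus the "Name" class column
data_url = "https://raw.githubusercontent.com/pandas-dev/pandas/main/pandas/tests/io/data/csv/iris.csv"
df_iris = pd.read_csv(data_url)

# Andrews curves: each row becomes a finite Fourier series, colored by class
plt.figure(figsize=(12, 7))
pd.plotting.andrews_curves(df_iris, class_column="Name")

# Parallel coordinates: each row becomes a polyline across one axis per feature
plt.figure(figsize=(12, 7))
pd.plotting.parallel_coordinates(df_iris, class_column="Name")

# Scatter matrix: pairwise scatter plots of the numeric features, histograms on the diagonal
pd.plotting.scatter_matrix(df_iris.drop("Name", axis=1), figsize=(8, 8), alpha=0.7)

# Display all figures (needed when running as a script rather than in a notebook)
plt.show()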
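To make the Andrews curves less opaque: in the classic Andrews transform, an observation x = (x1, x2, x3, ...) becomes the function f(t) = x1/√2 + x2 sin(t) + x3 cos(t) + x4 sin(2t) + x5 cos(2t) + ..., plotted for t in [-π, π]. Because the transform is linear and its basis functions are orthogonal on that interval, observations that are close in feature space produce curves that stay close, which is why each class tends to appear as a band. The sketch below implements that textbook definition with NumPy; the helper name `andrews_curve` and the sample observations are hypothetical, and pandas' internal implementation may differ in details.

```python
import numpy as np

def andrews_curve(x, t):
    """Evaluate the classic Andrews function of one observation x at angles t.

    f(t) = x1/sqrt(2) + x2 sin(t) + x3 cos(t) + x4 sin(2t) + x5 cos(2t) + ...
    """
    x = np.asarray(x, dtype=float)
    t = np.asarray(t, dtype=float)
    # Constant term from the first feature
    result = np.full_like(t, x[0] / np.sqrt(2.0))
    for i, coef in enumerate(x[1:]):
        k = i // 2 + 1                            # harmonic: 1, 1, 2, 2, 3, 3, ...
        trig = np.sin if i % 2 == 0 else np.cos   # alternate sin and cos
        result += coef * trig(k * t)
    return result

# Hypothetical observation with four features, e.g. four Iris measurements
t = np.linspace(-np.pi, np.pi, 200)
curve = andrews_curve([5.1, 3.5, 1.4, 0.2], t)
print(curve.shape)
print(andrews_curve([1.0, 2.0, 3.0, 4.0], [0.0]))
```

Evaluating one such curve per row and coloring it by class label gives the kind of picture that `pd.plotting.andrews_curves` draws.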
The resulting plots are shown below:
The source code for this work can be found in this Jupyter Notebook. If you like this and would like to see similar content, you can follow me on LinkedIn or Twitter. You can also subscribe to the mailing list to get similar updates from time to time.