跳转至

数据可视化库

Python 没有内置数据可视化库,但是第三方库比较丰富。本文就以 Matplotlib 为例。

Matplotlib

An object-oriented plotting library.

基本概念

安装方法

uv add matplotlib
pip install matplotlib

设计哲学

仍旧是面向对象。

在 Matplotlib 视角下,一张图由以下几个部分组成:

  • Figure(画布):最顶层的容器,控制整张图的大小、背景色和 DPI;
  • Axes(子图):真正的绘图区域。一个 Figure 可以包含多个 Axes;
  • Artist(元素):子图上所有东西,比如:坐标轴、线条、文字、多边形等都是 Artist 对象。
graph TD
    Figure[画布]

    Figure --> Axes1[子图 1]
    Figure --> Axes2[子图 2]
    Figure --> Axes3[子图 3]

    Axes1 --> Line[折线]
    Axes1 --> Scatter[散点]
    Axes1 --> Legend[图例]

    Axes2 --> Pai[饼图]
    Axes2 --> Patch[区域]

    Axes3 --> Text[文本]

操作画布

画布 (Figure) 是 Matplotlib 的 「根」对象,所有绘图行为最终都落在画布上。

那么很自然的,所有全局级的配置也应当在画布上,比如:

  • 尺寸(以英寸为单位);
  • 分辨率(DPI,即每英寸像素数);
  • 全局样式(字体、线宽、调色板)。

全局配置

Matplotlib 在导入时会加载一套默认样式,这些配置全部存放在全局样式表字典 matplotlib.rcParams 中,所有 Figure / Axes 在创建时,都会从 rcParams 中读取默认值:

import matplotlib as mpl

mpl.rcParams.update(
    {
        "figure.dpi": 120,  # 运行时 DPI
        "savefig.dpi": 120,  # 持久化 DPI
        "font.family": "SimHei",  # SimHei 支持中文,论文字体一般用 Times New Roman
        "axes.unicode_minus": False,  # 不编码减号
    }
)

初始化画布

你也许看到过直接用 plt.plot() 来绘图的代码,这是因为 Matplotlib 隐式创建了 Figure:

1
2
3
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 4), dpi=100)

但我并不推荐这种写法,因为所有的子图 add_subplot、元素 add_axes 等都需要自定义,容易出错。

我更推荐使用 plt.subplots 来初始化画布对象。例如:

fig, ax = plt.subplots(figsize=(8, 5), dpi=120)

这种方式方便我们提前创建好子图对象(上述的 ax),后续针对 ax 绘制元素即可。这种面向对象的方式更符合编码直觉。

持久化画布

使用 matplotlib.pyplotsavefig 方法即可。例如:

1
2
3
4
5
import matplotlib.pyplot as plt

...

plt.savefig("/path/to/image.svg")

操作子图

子图 (Axes) 是承载数据和坐标系的对象。

初始化单子图

import matplotlib as mpl
import matplotlib.pyplot as plt

# 全局配置
mpl.rcParams.update(
    {
        "figure.dpi": 80,
        "savefig.dpi": 80,
    }
)

# 创建画布和一个子图
fig, ax = plt.subplots(figsize=(4, 3))

# 绘制元素 - 折线图
ax.plot([1, 2, 3], [1, 4, 9])

# 绘制标题
dpi = fig.get_dpi()
w_inch, h_inch = fig.get_size_inches()
ax.set_title(f"Resolution: {dpi * w_inch:.0f} * {dpi * h_inch:.0f}")

# 持久化画布
plt.savefig("image.svg")

输出:

image.svg

关于 SVG 图像与分辨率

我标注分辨率是为了表示图像的大小,但请注意,SVG 没有分辨率一说。

初始化多子图

多子图原则

每个 Axes 应该是自解释的,不依赖上下文也能读懂。

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# 全局配置
mpl.rcParams.update(
    {
        "figure.dpi": 80,
        "savefig.dpi": 80,
    }
)

# 创建画布和子图
fig, axes = plt.subplots(
    nrows=2,
    ncols=1,
    figsize=(4, 3),
    sharex=False,  # 不共享 X 轴
)

# 示例数据
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)

# 在第一个子图绘制折线图 - 正弦曲线
axes[0].plot(x, y1, color="blue")
axes[0].set_title("y = sin(x)")

# 在第二个子图绘制折线图 - 余弦曲线
axes[1].plot(x, y2, color="red")
axes[1].set_title("y = cos(x)")

# 持久化
plt.tight_layout()  # 优化布局
plt.savefig("image.svg")

其中 axes 是一个 \(2\times 1\)numpy.ndarray,每个元素都是一个子图实例。

sharex=False 的情况下进行输出:

image.svg

image.svg

plt.tight_layout() 的情况下进行输出:

image.svg

image.svg

操作元素

子图中一切可见对象,都是 Artist,即元素。每次调用元素创建方法,都会创建一个 Artist 对象并注册到 Axes 的内部列表,最终由 Figure 统一渲染。

常见元素类型有:坐标轴、折线、散点、图例、注释等。

绘制坐标轴

我一直称呼 Axes 为子图,但其实它的英文释义是轴。那既然是轴,就可以控制轴的各种属性,比如:范围、刻度、标签、比例等。

设置轴范围:

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# 全局配置
mpl.rcParams.update(
    {
        "savefig.dpi": 80,
        "figure.dpi": 80,
    }
)

x = np.linspace(0, 20, 200)
y = np.sin(x)

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(x, y)

# 限制显示范围
ax.set_xlim(0, 10)

plt.savefig("image.svg")

image.svg

image.svg

设置轴刻度:

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# 全局配置
mpl.rcParams.update(
    {
        "savefig.dpi": 80,
        "figure.dpi": 80,
    }
)

x = np.linspace(0, 20, 200)
y = np.sin(x)

fig, ax = plt.subplots()
ax.plot(x, y)

# 设置范围
ax.set_xlim(0, 10)

# 设置刻度
ax.set_xticks([i for i in range(11)])
ax.set_yticks([-1, -0.5, 0, 0.5, 1])

# 自定义标签(标签数量需要和刻度数量一致)
ax.set_yticklabels(["min", "-0.5", "zero", "0.5", "max"])

plt.savefig("image.svg")

image.svg

image.svg

设置轴标签:

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# 全局配置
mpl.rcParams.update(
    {
        "savefig.dpi": 80,
        "figure.dpi": 80,
    }
)

x = np.linspace(0, 20, 200)
y = np.sin(x)

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(x, y)

# 设置轴标签
ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude")

plt.tight_layout()
plt.savefig("image.svg")

image.svg

image.svg

设置轴比例:

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# 全局配置
mpl.rcParams.update(
    {
        "savefig.dpi": 80,
        "figure.dpi": 80,
    }
)

x = np.linspace(0, 20, 200)
y = 2 * x

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(x, y)

# 设置轴比例
ax.set_yscale("log")

plt.savefig("image.svg")

image.svg

image.svg

绘制折线

适合:趋势分析、连续变量。例如:

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# 全局配置
mpl.rcParams.update(
    {
        "savefig.dpi": 80,
        "figure.dpi": 80,
    }
)

x = np.linspace(-3, 3)
y = np.exp(-x)

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(
    x,
    y,
    linewidth=2,
    linestyle="--",
)

plt.savefig("image.svg")

输出:

image.svg

绘制散点

适合:离散样本、分布可视化、聚类结果。例如:

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# 全局配置
mpl.rcParams.update(
    {
        "savefig.dpi": 80,
        "figure.dpi": 80,
    }
)

x = np.linspace(-3, 3)
y = np.exp(-x)

fig, ax = plt.subplots(figsize=(4, 3))
ax.scatter(
    x,
    y,
    c="red",
    s=20,  # 散点大小
    alpha=0.3,  # 透明度
)

plt.savefig("image.svg")

输出:

image.svg

绘制图例

适合:提示每一个元素的信息。例如:

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# 全局配置
mpl.rcParams.update(
    {
        "savefig.dpi": 80,
        "figure.dpi": 80,
    }
)

x = np.linspace(-3, 3)
y_train = np.exp(-x)
y_test = x**2

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(
    x,
    y_train,
    c="red",
    label="train loss",
    alpha=0.3,
)
ax.plot(
    x,
    y_test,
    c="blue",
    label="test loss",
    alpha=0.3,
)

# 显示标签
plt.legend(
    loc="upper right",  # 显示位置,默认为 "best"
)

plt.savefig("image.svg")

输出:

image.svg

绘制注释

适合:强调关键点、解释异常值。例如:

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# 全局配置
mpl.rcParams.update(
    {
        "savefig.dpi": 80,
        "figure.dpi": 80,
    }
)

x = np.linspace(-3, 3)
y_train = np.exp(-x)
y_test = x**2

fig, ax = plt.subplots(figsize=(4, 3))

ax.set_ylim(-1, 3)
ax.plot(x, y_test, c="blue", label="test loss", alpha=0.3)

idx = np.argmin(y_test)
x_min = x[idx]
y_min = y_test[idx]

# 设置注释
ax.annotate(
    f"({x_min:.1f}, {y_min:.1f})",  # 注释文本
    xy=(x_min, y_min),  # 箭头终点
    xytext=(x_min, y_min + 2),  # 箭头起点(注释开始的地方)
    arrowprops=dict(
        arrowstyle="->",  # 箭头样式
        lw=1,  # 箭头粗细
    ),
    fontsize=9,
)

plt.legend()

plt.savefig("image.svg")

输出:

image.svg

沃兹基硕德

已经开始幻想自己画出很牛比的图放到论文里了 🤤。